# Set-up

In [None]:
import os

# NOTE: '__file__' is a literal string here (the __file__ variable is not defined in a
# notebook), so realpath() resolves it against the current working directory and
# dirname() returns that directory (normally the notebook's folder under Jupyter).
script_dir = os.path.dirname(os.path.realpath('__file__'))
# Project root: one level above the notebook directory.
parent_dir = os.path.dirname(script_dir)

## Importing modules

In [None]:
# Standard library
import os
import sys
import random
import glob
import pickle
from pathlib import Path
import multiprocessing
import warnings

# Data manipulation
import numpy as np
import pandas as pd

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rc_context

# Single-cell analysis
import scanpy as sc
import anndata
import harmonypy as hm
import espressopro as ep

# Machine learning - sklearn
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import StackingClassifier
from sklearn.model_selection import train_test_split, KFold
from sklearn.calibration import CalibratedClassifierCV
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    roc_curve,
    auc,
    ConfusionMatrixDisplay,
)
from sklearn import preprocessing

# Additional ML utilities
from netcal.scaling import TemperatureScaling
from scipy.stats import skew, kurtosis, normaltest, shapiro
import joblib

# Configure multiprocessing: reserve two cores, but never go below one worker
# (cpu_count() - 2 would be 0 or negative on 1-2 core machines, which is not a
# valid worker count for most parallel APIs).
num_cores = max(1, multiprocessing.cpu_count() - 2)
print(f"Total CPU cores to be used: {num_cores}")

# Shuffled 5-fold splitter with a fixed random_state, so fold membership is
# reproducible between runs.
kf = KFold(n_splits=5, shuffle=True, random_state=42)

# Blanket "ignore" filter for every warning category from here on.
warnings.simplefilter('ignore')

In [None]:
# Re-applies the same blanket "ignore" warning filter as the imports cell
# (redundant when that cell has already run).
import warnings
warnings.filterwarnings('ignore')

## Loading custom scripts

In [None]:
# Make the project's own helper modules importable: append the Scripts/
# subfolders (relative to parent_dir) to sys.path, then import them.
# NOTE(review): re-running this cell appends the same paths again (harmless for imports).
import sys
sys.path.append(parent_dir + '/Scripts/SingleCellUtils')

import SCUtils

import sys
sys.path.append(parent_dir + '/Scripts/ModelTraining')

import MLTraining

In [None]:
# =============================================================================
# HELPER FUNCTIONS
# =============================================================================

def assign_labels(dataset, reduction, n_neighbors, label_input, label_output,
                  frequency_threshold, resolution=10):
    """Propagate consensus labels within high-resolution Leiden clusters.

    Builds a kNN graph from ``reduction`` (passed as ``use_rep``), clusters it
    with Leiden into ``dataset.obs['clusters']`` and copies ``label_input`` into
    ``label_output``. For every cluster whose most frequent ``label_input``
    value covers strictly more than ``frequency_threshold`` of the cluster's
    cells, all cells of that cluster get that label in ``label_output``.

    Parameters
    ----------
    dataset : AnnData
        Modified in place (neighbors graph, ``obs['clusters']``,
        ``obs[label_output]``).
    reduction : str
        Representation forwarded to ``sc.pp.neighbors(use_rep=...)``.
    n_neighbors : int
        Number of neighbours for the kNN graph.
    label_input, label_output : str
        Source and destination ``obs`` column names.
    frequency_threshold : float
        Fraction of the cluster the majority label must exceed.
    resolution : float, default 10
        Leiden resolution (previously hard-coded to 10).

    Returns
    -------
    AnnData
        The same ``dataset`` object.
    """
    sc.pp.neighbors(dataset, use_rep=reduction, n_neighbors=n_neighbors)
    sc.tl.leiden(dataset, key_added='clusters', resolution=resolution)
    dataset.obs[label_output] = dataset.obs[label_input]

    clusters = dataset.obs['clusters']
    for cluster in clusters.unique():
        in_cluster = clusters == cluster  # build the boolean mask once per cluster
        cluster_labels = dataset.obs.loc[in_cluster, label_input]
        modes = cluster_labels.mode()
        if modes.empty:
            # Every label in this cluster is missing: mode() is empty and there
            # is nothing to propagate (indexing [0] would raise IndexError).
            continue
        most_frequent_label = modes.iloc[0]
        frequency = (cluster_labels == most_frequent_label).mean()

        if frequency > frequency_threshold:
            dataset.obs.loc[in_cluster, label_output] = most_frequent_label

    return dataset


def _norm_feats(names) -> pd.Index:
    """Canonicalise feature names.

    Each name is converted to ``str``, stripped of surrounding whitespace and
    lowercased. Any run of spaces, underscores or slashes becomes one hyphen,
    repeated hyphens collapse to one, and leading/trailing hyphens are removed.
    """
    idx = pd.Index([str(n) for n in names])
    idx = idx.str.strip().str.lower()
    idx = idx.str.replace(r"[ _/]+", "-", regex=True)
    idx = idx.str.replace(r"-+", "-", regex=True)
    return idx.str.strip("-")


def attach_celltype(df: pd.DataFrame, ad: "AnnData", field: str) -> pd.DataFrame:
    """Return a copy of *df* with a categorical ``Celltype`` column.

    Labels are taken from ``ad.obs[field]``. Surrounding whitespace is stripped
    and inner whitespace runs become ``_``. Labels are aligned to ``df.index``,
    and rows without a match get NaN, which is reported with a warning print.

    Raises
    ------
    KeyError
        If *field* is not a column of ``ad.obs``.
    """
    if field not in ad.obs:
        raise KeyError(f"'{field}' not found in AnnData.obs")

    raw = ad.obs[field].astype("string")
    cleaned = raw.str.strip().str.replace(r"\s+", "_", regex=True)

    result = df.copy()
    result["Celltype"] = pd.Categorical(cleaned.reindex(result.index))

    n_missing = result["Celltype"].isna().sum()
    if n_missing:
        print(f"[WARN] {n_missing} rows got NaN Celltype")

    return result


def _check_finite(df: pd.DataFrame, tag: str):
    """Raise ``ValueError`` if *df* contains NaN or +/-inf values.

    The message gives the number of non-finite cells and the (row, column)
    labels of at most the first five. Listing every position is avoided
    because it can be enormous for large matrices.

    Raises
    ------
    ValueError
        If any entry of ``df.to_numpy()`` is not finite.
    """
    arr = df.to_numpy()
    bad_mask = ~np.isfinite(arr)  # computed once, reused below
    if not bad_mask.any():
        return
    rows, cols = np.nonzero(bad_mask)
    n_bad = rows.size
    examples = [(df.index[r], df.columns[c]) for r, c in zip(rows[:5], cols[:5])]
    raise ValueError(
        f"Non-finite values in {tag}: {n_bad} cell(s), first at (row, col) = {examples}"
    )


def _unwrap_estimator(m):
    """Return the estimator wrapped by a meta-estimator, or *m* itself.

    Checks ``m.estimator`` and then ``m.base_estimator``, and returns the first
    one that is set and is not a string. Explicit ``is None`` checks replace a
    truthiness ``or``-chain for two reasons. Some estimators define
    ``__len__`` (e.g. ``Pipeline``), so a valid wrapped estimator can be falsy.
    Some scikit-learn versions also store a placeholder string such as
    ``"deprecated"`` in ``base_estimator``.
    """
    for attr in ("estimator", "base_estimator"):
        inner = getattr(m, attr, None)
        if inner is not None and not isinstance(inner, str):
            return inner
    return m


def _assert_feature_counts(cell_name: str, models_dict: dict, expected: int):
    """Verify that every trained model saw exactly *expected* input features.

    Checks the "NB", "XGB", "KNN", "MLP" and "Stacker" entries of
    *models_dict*, skipping missing or ``None`` entries. The model's own
    ``n_features_in_`` is used when present. Only when it is absent is the
    wrapped estimator (``_unwrap_estimator``) consulted.

    The own attribute comes first because of wrappers like
    ``CalibratedClassifierCV``: its ``estimator`` attribute is the unfitted
    template (unless ``cv="prefit"``) and has no ``n_features_in_``. Checking
    only the unwrapped object would silently skip such models.

    Raises
    ------
    RuntimeError
        If a model reports a feature count different from *expected*.
    """
    for name in ("NB", "XGB", "KNN", "MLP", "Stacker"):
        est = models_dict.get(name)
        if est is None:
            continue
        nfi = getattr(est, "n_features_in_", None)
        if nfi is None:
            # Fall back to the wrapped estimator for wrappers that don't expose it.
            nfi = getattr(_unwrap_estimator(est), "n_features_in_", None)
        if nfi is not None and nfi != expected:
            raise RuntimeError(f"{cell_name}:{name} saw {nfi} features; expected {expected}")

In [None]:
# IPython (auto)magic equivalent of `%pip list`: prints installed package versions in the cell output.
pip list

`PYTHONHASHSEED` was set to 0 as a conda environment variable (it only takes effect if it is set before the Python interpreter starts):
    
    conda env config vars set PYTHONHASHSEED=0

In [None]:
# Seed Python's `random` module and NumPy's global RNG with 42.
# NOTE: assigning os.environ['PYTHONHASHSEED'] here only affects processes started
# afterwards; the running interpreter's str/bytes hash seed is fixed at start-up,
# so the variable must already be set when the kernel launches.
os.environ['PYTHONHASHSEED'] = '0'
random.seed(42)
np.random.seed(42)

In [None]:
def ensure_pythonhashseed(seed=0):
    """Ensure the interpreter runs with ``PYTHONHASHSEED == str(seed)``.

    If the environment variable already equals ``str(seed)``, this returns
    without doing anything. Because of that, it must run before any code
    assigns ``os.environ['PYTHONHASHSEED']`` in-process. Otherwise it sets the
    variable and re-executes the current process via ``os.execl`` with the same
    interpreter and ``sys.argv``. A restart is needed because the hash seed is
    only read at interpreter start-up.

    Parameters
    ----------
    seed : int or str, default 0
        Desired hash seed; compared as a string.
    """
    seed = str(seed)
    if os.environ.get("PYTHONHASHSEED") == seed:
        return
    print(f'Setting PYTHONHASHSEED="{seed}"')
    os.environ["PYTHONHASHSEED"] = seed
    # os.exec* replaces the process without flushing Python's buffered streams,
    # so flush explicitly or the message above (and pending output) is lost.
    sys.stdout.flush()
    sys.stderr.flush()
    # restart the current process with the same interpreter and arguments
    os.execl(sys.executable, sys.executable, *sys.argv)

In [None]:
import random

# 128-bit value drawn from the `random` module, printed as 32 hex digits.
# Presumably a quick reproducibility check: if random.seed(42) ran first, it should
# be identical on every run. Named `run_hash` (not `hash`) so the built-in hash()
# is not shadowed for the rest of the notebook.
run_hash = random.getrandbits(128)

print("hash value: %032x" % run_hash)

## Defining data path

In [None]:
# Input data and figure output locations, relative to the project root.
data_path = parent_dir + "/Data"
figures_path = parent_dir + "/Figures/Model_Training"

# Create the output folders; exist_ok=True makes re-running the cell a no-op
# without a separate (race-prone) existence check.
os.makedirs(figures_path, exist_ok=True)
os.makedirs(data_path + "/Pre_trained_models", exist_ok=True)

# Loading datasets

In [None]:
# Paths to the training/testing barcode locations (nothing is read in this cell).
# Only the last bare expression of a cell is echoed by Jupyter, so only
# test_barcodes_path is displayed.
train_barcodes_path = data_path + "/Training_barcodes"
train_barcodes_path
test_barcodes_path = data_path + "/Testing_barcodes"
test_barcodes_path

## Loading Hao Y. et al. (2021) dataset

In [None]:
# Load the pre-split Train / Test / Cal AnnData files of the Hao reference
# (Cal presumably = calibration split — confirm).
Hao_dataset_Train = sc.read_h5ad(data_path + "/References/Hao" + "/228AB_healthy_donors_PBMNCs_annotated_Train.h5ad")
Hao_dataset_Test = sc.read_h5ad(data_path + "/References/Hao" + "/228AB_healthy_donors_PBMNCs_annotated_Test.h5ad")
Hao_dataset_Cal = sc.read_h5ad(data_path + "/References/Hao" + "/228AB_healthy_donors_PBMNCs_annotated_Cal.h5ad")

### Dataset Description

In [None]:
# Display the AnnData summary of the Hao training split.
Hao_dataset_Train

In [None]:
# Display the AnnData summary of the Hao Cal split.
Hao_dataset_Cal

In [None]:
# --- Config ---
# Plot the Hao TRAIN split on the 'X_wnn.umap' embedding, coloured by the detailed
# consensus annotation; store the colours in .uns and save the PNG to figures_path.
label_key = 'Consensus_annotation_detailed_final'
basis_key = 'X_wnn.umap'
color_key = f'{label_key}_colors'

# Your custom palette (label -> hex)
custom_palette = {
    'B Memory': "#68D827",
    'B Naive': '#1C511D',
    'CD14 Mono': "#D27CE3",
    'CD16 Mono': "#8D43CD",
    'CD4 T Memory': "#C1AF93",
    'CD4 T Naive': "#C99546",
    'CD8 T Memory': "#6B3317",
    'CD8 T Naive': "#4D382E",
    'ErP': "#D1235A",
    'Erythroblast': "#F30A1A",
    'GMP': "#C5E4FF",
    'HSC_MPP': '#0079ea',
    'Immature B': "#91FF7B",
    'LMPP': "#17BECF",
    'MAIT': "#BCBD22",
    'Myeloid progenitor': "#AEC7E8",
    'NK CD56 bright': "#F3AC1F",
    'NK CD56 dim': "#FBEF0D",
    'Plasma': "#9DC012",
    'Pro-B': "#66BB6A",
    'Small': "#292929",
    'cDC1': "#76A7CB",
    'cDC2': "#16D2E3",
    'GdT': "#EDB416",
    'Mesenchymal': '#BBBBBB',
    'pDC': "#69FFCB",
    'CD4 CTL': "#D7D2CB",
    'MEP': "#E364B0",
    'Pre-B': "#2DBD67",
    'Pre-Pro-B': '#92AC8E',
    'EoBaMaP': "#728245",
    'MkP': "#69424D",
    'Stroma': "#727272",
    'Macrophage': "#5F4761",
    'ILC': "#F7CF94",
    'DnT': "#504423",
    'Treg': "#6E6C37",
    'Platelet': "#FF39A6",
}

# --- Ensure categorical dtype ---
# isinstance(dtype, pd.CategoricalDtype) replaces pd.api.types.is_categorical_dtype,
# which is deprecated since pandas 2.1.
if not isinstance(Hao_dataset_Train.obs[label_key].dtype, pd.CategoricalDtype):
    Hao_dataset_Train.obs[label_key] = Hao_dataset_Train.obs[label_key].astype('category')

cats = list(Hao_dataset_Train.obs[label_key].cat.categories)

# --- Sanity checks: missing/extra labels ---
labels_in_palette = set(custom_palette.keys())
labels_in_data = set(cats)

missing_in_palette = [c for c in cats if c not in labels_in_palette]
extra_in_palette   = [c for c in custom_palette.keys() if c not in labels_in_data]

if missing_in_palette:
    print("[WARN] Missing colors for:", missing_in_palette, "-> will use light grey (#cccccc).")
if extra_in_palette:
    print("[INFO] Palette has unused entries:", extra_in_palette)

# --- Build palette list in *category order* ---
fallback = '#cccccc'
palette_list = [custom_palette.get(c, fallback) for c in cats]

# --- Save onto .uns so Scanpy uses it consistently elsewhere ---
Hao_dataset_Train.uns[color_key] = palette_list

# --- Plot with Scanpy using built-in outlines (no clustering) ---
with rc_context({"figure.figsize": (5.2, 4.5)}):
    sc.pl.embedding(
        Hao_dataset_Train,
        basis=basis_key,
        color=label_key,
        palette=palette_list,
        legend_loc='on data',
        legend_fontsize=10,
        legend_fontoutline=1.5,
        size=10,
        add_outline=True,   # built-in group outlines
        frameon=False,
        title='Hao',
        show=False,
    )
    ax = plt.gca()
    ax.set_xlabel('')
    ax.set_ylabel('')
    plt.tight_layout()
    plt.savefig(figures_path + "/Hao_final_annotation.png",
        dpi=300, bbox_inches='tight')
    plt.show()


In [None]:
# List the distinct detailed consensus labels present in the Hao training split.
pop_labels = Hao_dataset_Train.obs['Consensus_annotation_detailed_final'].values
unique_pops = np.unique(pop_labels)
print(f"Found {len(unique_pops)} populations:", unique_pops)

In [None]:
# Normalise protein data in place for each split with espressopro's "seurat" flavour.
# NOTE(review): the meaning of axis=1 is defined by espressopro — confirm per-cell vs per-feature.
ep.Normalise_protein_data(Hao_dataset_Train, inplace=True, axis=1, flavor="seurat")
ep.Normalise_protein_data(Hao_dataset_Cal,   inplace=True, axis=1, flavor="seurat")
ep.Normalise_protein_data(Hao_dataset_Test,  inplace=True, axis=1, flavor="seurat")

In [None]:
# Wilcoxon ranking of features per detailed consensus label (TRAIN split only),
# then plot the top 10 features per group.
sc.tl.rank_genes_groups(Hao_dataset_Train, 'Consensus_annotation_detailed_final', method='wilcoxon')
sc.pl.rank_genes_groups(Hao_dataset_Train, n_genes=10, sharey=False, ncols = 3, fontsize = 14)

In [None]:
# Matrix plot of the ranked features computed by the previous cell.
sc.pl.rank_genes_groups_matrixplot(Hao_dataset_Train)

In [None]:
# Display the (now normalised and ranked) Hao training AnnData summary.
Hao_dataset_Train

In [None]:
# CD56 expression (from .X, use_raw=False) grouped by the 'celltype.l2' obs column.
sc.pl.violin(Hao_dataset_Train, keys='CD56', groupby='celltype.l2', rotation=90, use_raw=False)

In [None]:
# CD56 expression (from .X) grouped by the detailed consensus annotation.
sc.pl.violin(Hao_dataset_Train, keys='CD56', groupby='Consensus_annotation_detailed_final', rotation=90, use_raw=False)

### ML pre-processing

In [None]:
# Wrap each split's .X in a sparse-backed DataFrame (rows = obs_names, columns = var_names).
# NOTE(review): DataFrame.sparse.from_spmatrix requires a scipy sparse matrix; if .X
# is dense (e.g. after normalisation) this raises — confirm .X stays sparse.
Hao_data_Train = pd.DataFrame.sparse.from_spmatrix(Hao_dataset_Train.X, index=Hao_dataset_Train.obs_names, columns=Hao_dataset_Train.var_names)
Hao_data_Test = pd.DataFrame.sparse.from_spmatrix(Hao_dataset_Test.X, index=Hao_dataset_Test.obs_names, columns=Hao_dataset_Test.var_names)
Hao_data_Cal = pd.DataFrame.sparse.from_spmatrix(Hao_dataset_Cal.X, index=Hao_dataset_Cal.obs_names, columns=Hao_dataset_Cal.var_names)

In [None]:
# Protein features excluded from the model inputs (absent columns are silently skipped).
columns_to_drop = ["IgD","IgM", "Rag-IgG2c"]
Hao_data_Train, Hao_data_Test, Hao_data_Cal = (
    frame.drop(columns=columns_to_drop, errors='ignore')
    for frame in (Hao_data_Train, Hao_data_Test, Hao_data_Cal)
)

## Loading Zhang X. et al. (2024) dataset

In [None]:
# Load the pre-split Train / Test / Cal AnnData files of the Zhang reference.
Zhang_dataset_Train = sc.read_h5ad(data_path + "/References/Zhang" + "/Zhang_adata_annotated_Train.h5ad")
Zhang_dataset_Test = sc.read_h5ad(data_path + "/References/Zhang" + "/Zhang_adata_annotated_Test.h5ad")
Zhang_dataset_Cal = sc.read_h5ad(data_path + "/References/Zhang" + "/Zhang_adata_annotated_Cal.h5ad")

### Dataset Description

In [None]:
# Display the AnnData summary of the Zhang training split.
Zhang_dataset_Train

In [None]:
# --- Config ---
# Plot the Zhang TRAIN split on the 'X_umap' embedding, coloured by the detailed
# consensus annotation; store the colours in .uns and save the PNG to figures_path.
label_key = 'Consensus_annotation_detailed_final'
basis_key = 'X_umap'
color_key = f'{label_key}_colors'

# Your custom palette (label -> hex); same entries as the other datasets' palettes
# (including 'Treg' and 'Platelet') so shared labels always get the same colour.
custom_palette = {
    'B Memory': "#68D827",
    'B Naive': '#1C511D',
    'CD14 Mono': "#D27CE3",
    'CD16 Mono': "#8D43CD",
    'CD4 T Memory': "#C1AF93",
    'CD4 T Naive': "#C99546",
    'CD8 T Memory': "#6B3317",
    'CD8 T Naive': "#4D382E",
    'ErP': "#D1235A",
    'Erythroblast': "#F30A1A",
    'GMP': "#C5E4FF",
    'HSC_MPP': '#0079ea',
    'Immature B': "#91FF7B",
    'LMPP': "#17BECF",
    'MAIT': "#BCBD22",
    'Myeloid progenitor': "#AEC7E8",
    'NK CD56 bright': "#F3AC1F",
    'NK CD56 dim': "#FBEF0D",
    'Plasma': "#9DC012",
    'Pro-B': "#66BB6A",
    'Small': "#292929",
    'cDC1': "#76A7CB",
    'cDC2': "#16D2E3",
    'GdT': "#EDB416",
    'Mesenchymal': '#BBBBBB',
    'pDC': "#69FFCB",
    'CD4 CTL': "#D7D2CB",
    'MEP': "#E364B0",
    'Pre-B': "#2DBD67",
    'Pre-Pro-B': '#92AC8E',
    'EoBaMaP': "#728245",
    'MkP': "#69424D",
    'Stroma': "#727272",
    'Macrophage': "#5F4761",
    'ILC': "#F7CF94",
    'DnT': "#504423",
    'Treg': "#6E6C37",
    'Platelet': "#FF39A6",
}

# --- Ensure categorical dtype ---
# isinstance(dtype, pd.CategoricalDtype) replaces pd.api.types.is_categorical_dtype,
# which is deprecated since pandas 2.1.
if not isinstance(Zhang_dataset_Train.obs[label_key].dtype, pd.CategoricalDtype):
    Zhang_dataset_Train.obs[label_key] = Zhang_dataset_Train.obs[label_key].astype('category')

cats = list(Zhang_dataset_Train.obs[label_key].cat.categories)

# --- Sanity checks: missing/extra labels ---
labels_in_palette = set(custom_palette.keys())
labels_in_data = set(cats)

missing_in_palette = [c for c in cats if c not in labels_in_palette]
extra_in_palette   = [c for c in custom_palette.keys() if c not in labels_in_data]

if missing_in_palette:
    print("[WARN] Missing colors for:", missing_in_palette, "-> will use light grey (#cccccc).")
if extra_in_palette:
    print("[INFO] Palette has unused entries:", extra_in_palette)

# --- Build palette list in *category order* ---
fallback = '#cccccc'
palette_list = [custom_palette.get(c, fallback) for c in cats]

# --- Save onto .uns so Scanpy uses it consistently elsewhere ---
Zhang_dataset_Train.uns[color_key] = palette_list

# --- Plot with Scanpy using built-in outlines (no clustering) ---
with rc_context({"figure.figsize": (5.5, 4.5)}):
    sc.pl.embedding(
        Zhang_dataset_Train,
        basis=basis_key,
        color=label_key,
        palette=palette_list,
        legend_loc='on data',
        legend_fontsize=10,
        legend_fontoutline=1.5,
        size=10,
        add_outline=True,   # built-in group outlines
        frameon=False,
        title='Zhang',
        show=False,
    )
    ax = plt.gca()
    ax.set_xlabel('')
    ax.set_ylabel('')
    plt.tight_layout()
    plt.savefig(figures_path + "/Zhang_final_annotation.png",
        dpi=300, bbox_inches='tight')
    plt.show()


In [None]:
# Normalise protein data in place for each Zhang split (espressopro, "seurat" flavour).
ep.Normalise_protein_data(Zhang_dataset_Train, inplace=True, axis=1, flavor="seurat")
ep.Normalise_protein_data(Zhang_dataset_Cal,   inplace=True, axis=1, flavor="seurat")
ep.Normalise_protein_data(Zhang_dataset_Test,  inplace=True, axis=1, flavor="seurat")

In [None]:
# Wilcoxon ranking of features per detailed consensus label (TRAIN only) and top-10 plot.
sc.tl.rank_genes_groups(Zhang_dataset_Train, 'Consensus_annotation_detailed_final', method='wilcoxon')
sc.pl.rank_genes_groups(Zhang_dataset_Train, n_genes=10, sharey=False, ncols = 3, fontsize = 14)

In [None]:
# CD123 expression (from .X) grouped by the *broad* consensus annotation.
# NOTE(review): other datasets' violins group by the detailed annotation — confirm broad is intended here.
sc.pl.violin(Zhang_dataset_Train, keys='CD123', groupby='Consensus_annotation_broad_final', rotation=90, use_raw=False)

### ML pre-processing

In [None]:
# Wrap each split's .X in a sparse-backed DataFrame (rows = obs_names, columns = var_names).
# NOTE(review): from_spmatrix requires .X to be scipy-sparse — confirm after normalisation.
Zhang_data_Train = pd.DataFrame.sparse.from_spmatrix(Zhang_dataset_Train.X, index=Zhang_dataset_Train.obs_names, columns=Zhang_dataset_Train.var_names)
Zhang_data_Test = pd.DataFrame.sparse.from_spmatrix(Zhang_dataset_Test.X, index=Zhang_dataset_Test.obs_names, columns=Zhang_dataset_Test.var_names)
Zhang_data_Cal = pd.DataFrame.sparse.from_spmatrix(Zhang_dataset_Cal.X, index=Zhang_dataset_Cal.obs_names, columns=Zhang_dataset_Cal.var_names)

In [None]:
# Protein features excluded from the model inputs (absent columns are silently skipped).
columns_to_drop = ["IgG.Fc", "Isotype_G0114F7", "Isotype_HTK888",
                   "Isotype_MOPC.173", "Isotype_MOPC.21", "Isotype_MPC.11",
                   "Isotype_RTK2071", "Isotype_RTK2758", "Isotype_RTK4174",
                   "Isotype_RTK4530"]
Zhang_data_Train, Zhang_data_Test, Zhang_data_Cal = (
    frame.drop(columns=columns_to_drop, errors='ignore')
    for frame in (Zhang_data_Train, Zhang_data_Test, Zhang_data_Cal)
)

## Loading Triana S. et al. (2021) dataset

In [None]:
# Load the pre-split Train / Test / Cal AnnData files of the Triana reference.
Triana_dataset_Train = sc.read_h5ad(data_path + "/References/Triana" + "/97AB_young_and_old_adult_healthy_donor_BMMNCs_annotated_Train.h5ad")
Triana_dataset_Test = sc.read_h5ad(data_path + "/References/Triana" + "/97AB_young_and_old_adult_healthy_donor_BMMNCs_annotated_Test.h5ad")
Triana_dataset_Cal = sc.read_h5ad(data_path + "/References/Triana" + "/97AB_young_and_old_adult_healthy_donor_BMMNCs_annotated_Cal.h5ad")

### Dataset Description

In [None]:
# Display the AnnData summary of the Triana training split.
Triana_dataset_Train

In [None]:
# --- Config ---
# Plot the Triana TRAIN split on the 'X_mofaumap' embedding, coloured by the detailed
# consensus annotation; store the colours in .uns and save the PNG to figures_path.
label_key = 'Consensus_annotation_detailed_final'
basis_key = 'X_mofaumap'
color_key = f'{label_key}_colors'

# Your custom palette (label -> hex)
custom_palette = {
    'B Memory': "#68D827",
    'B Naive': '#1C511D',
    'CD14 Mono': "#D27CE3",
    'CD16 Mono': "#8D43CD",
    'CD4 T Memory': "#C1AF93",
    'CD4 T Naive': "#C99546",
    'CD8 T Memory': "#6B3317",
    'CD8 T Naive': "#4D382E",
    'ErP': "#D1235A",
    'Erythroblast': "#F30A1A",
    'GMP': "#C5E4FF",
    'HSC_MPP': '#0079ea',
    'Immature B': "#91FF7B",
    'LMPP': "#17BECF",
    'MAIT': "#BCBD22",
    'Myeloid progenitor': "#AEC7E8",
    'NK CD56 bright': "#F3AC1F",
    'NK CD56 dim': "#FBEF0D",
    'Plasma': "#9DC012",
    'Pro-B': "#66BB6A",
    'Small': "#292929",
    'cDC1': "#76A7CB",
    'cDC2': "#16D2E3",
    'GdT': "#EDB416",
    'Mesenchymal': '#BBBBBB',
    'pDC': "#69FFCB",
    'CD4 CTL': "#D7D2CB",
    'MEP': "#E364B0",
    'Pre-B': "#2DBD67",
    'Pre-Pro-B': '#92AC8E',
    'EoBaMaP': "#728245",
    'MkP': "#69424D",
    'Stroma': "#727272",
    'Macrophage': "#5F4761",
    'ILC': "#F7CF94",
    'DnT': "#504423",
    'Treg': "#6E6C37",
    'Platelet': "#FF39A6",
}

# --- Ensure categorical dtype ---
# isinstance(dtype, pd.CategoricalDtype) replaces pd.api.types.is_categorical_dtype,
# which is deprecated since pandas 2.1.
if not isinstance(Triana_dataset_Train.obs[label_key].dtype, pd.CategoricalDtype):
    Triana_dataset_Train.obs[label_key] = Triana_dataset_Train.obs[label_key].astype('category')

cats = list(Triana_dataset_Train.obs[label_key].cat.categories)

# --- Sanity checks: missing/extra labels ---
labels_in_palette = set(custom_palette.keys())
labels_in_data = set(cats)

missing_in_palette = [c for c in cats if c not in labels_in_palette]
extra_in_palette   = [c for c in custom_palette.keys() if c not in labels_in_data]

if missing_in_palette:
    print("[WARN] Missing colors for:", missing_in_palette, "-> will use light grey (#cccccc).")
if extra_in_palette:
    print("[INFO] Palette has unused entries:", extra_in_palette)

# --- Build palette list in *category order* ---
fallback = '#cccccc'
palette_list = [custom_palette.get(c, fallback) for c in cats]

# --- Save onto .uns so Scanpy uses it consistently elsewhere ---
Triana_dataset_Train.uns[color_key] = palette_list

# --- Plot with Scanpy using built-in outlines (no clustering) ---
with rc_context({"figure.figsize": (5.25, 4.5)}):
    sc.pl.embedding(
        Triana_dataset_Train,
        basis=basis_key,
        color=label_key,
        palette=palette_list,
        legend_loc='on data',
        legend_fontsize=10,
        legend_fontoutline=1.5,
        size=10,
        add_outline=True,   # built-in group outlines
        frameon=False,
        title='Triana',
        show=False,
    )
    ax = plt.gca()
    ax.set_xlabel('')
    ax.set_ylabel('')
    plt.tight_layout()
    plt.savefig(figures_path + "/Triana_final_annotation.png",
        dpi=300, bbox_inches='tight')
    plt.show()


In [None]:
# Normalise protein data in place for each Triana split (espressopro, "seurat" flavour).
ep.Normalise_protein_data(Triana_dataset_Train, inplace=True, axis=1, flavor="seurat")
ep.Normalise_protein_data(Triana_dataset_Cal,   inplace=True, axis=1, flavor="seurat")
ep.Normalise_protein_data(Triana_dataset_Test,  inplace=True, axis=1, flavor="seurat")

In [None]:
# Wilcoxon ranking of features per *simplified* consensus label (TRAIN only) and top-10 plot.
# NOTE(review): the other datasets rank by the detailed annotation — confirm simplified is intended.
sc.tl.rank_genes_groups(Triana_dataset_Train, 'Consensus_annotation_simplified_final', method='wilcoxon')
sc.pl.rank_genes_groups(Triana_dataset_Train, n_genes=10, sharey=False, ncols = 3, fontsize = 14)

In [None]:
# CD133 expression (from .X) grouped by the detailed consensus annotation.
sc.pl.violin(Triana_dataset_Train, keys='CD133', groupby='Consensus_annotation_detailed_final', rotation=90, use_raw=False)

### ML pre-processing

In [None]:
# Wrap each split's .X in a sparse-backed DataFrame (rows = obs_names, columns = var_names).
# NOTE(review): from_spmatrix requires .X to be scipy-sparse — confirm after normalisation.
Triana_data_Train = pd.DataFrame.sparse.from_spmatrix(Triana_dataset_Train.X, index=Triana_dataset_Train.obs_names, columns=Triana_dataset_Train.var_names)
Triana_data_Test = pd.DataFrame.sparse.from_spmatrix(Triana_dataset_Test.X, index=Triana_dataset_Test.obs_names, columns=Triana_dataset_Test.var_names)
Triana_data_Cal = pd.DataFrame.sparse.from_spmatrix(Triana_dataset_Cal.X, index=Triana_dataset_Cal.obs_names, columns=Triana_dataset_Cal.var_names)

In [None]:
# Assuming these are the columns to be dropped
columns_to_drop = ["IgG", "IgD"]
Triana_data_Train = Triana_data_Train.drop(columns=columns_to_drop, errors='ignore')
Triana_data_Test = Triana_data_Test.drop(columns=columns_to_drop, errors='ignore')
Triana_data_Cal = Triana_data_Cal.drop(columns=columns_to_drop, errors='ignore')

## Loading Luecken M.D. et al. (2021) dataset

In [None]:
# Load the pre-split Train / Test / Cal AnnData files of the Luecken reference.
Luecken_dataset_Train = sc.read_h5ad(data_path + "/References/Luecken" + "/140AB_adult_healthy_donor_BMMNCs_annotated_Train.h5ad")
Luecken_dataset_Test = sc.read_h5ad(data_path + "/References/Luecken" + "/140AB_adult_healthy_donor_BMMNCs_annotated_Test.h5ad")
Luecken_dataset_Cal = sc.read_h5ad(data_path + "/References/Luecken" + "/140AB_adult_healthy_donor_BMMNCs_annotated_Cal.h5ad")

### Dataset Description

In [None]:
# Display the AnnData summary of the Luecken training split.
Luecken_dataset_Train

In [None]:
# --- Config ---
# Plot the Luecken TRAIN split on the 'X_umap' embedding, coloured by the detailed
# consensus annotation; store the colours in .uns and save the PNG to figures_path.
label_key = 'Consensus_annotation_detailed_final'
basis_key = 'X_umap'
color_key = f'{label_key}_colors'

# Your custom palette (label -> hex)
custom_palette = {
    'B Memory': "#68D827",
    'B Naive': '#1C511D',
    'CD14 Mono': "#D27CE3",
    'CD16 Mono': "#8D43CD",
    'CD4 T Memory': "#C1AF93",
    'CD4 T Naive': "#C99546",
    'CD8 T Memory': "#6B3317",
    'CD8 T Naive': "#4D382E",
    'ErP': "#D1235A",
    'Erythroblast': "#F30A1A",
    'GMP': "#C5E4FF",
    'HSC_MPP': '#0079ea',
    'Immature B': "#91FF7B",
    'LMPP': "#17BECF",
    'MAIT': "#BCBD22",
    'Myeloid progenitor': "#AEC7E8",
    'NK CD56 bright': "#F3AC1F",
    'NK CD56 dim': "#FBEF0D",
    'Plasma': "#9DC012",
    'Pro-B': "#66BB6A",
    'Small': "#292929",
    'cDC1': "#76A7CB",
    'cDC2': "#16D2E3",
    'GdT': "#EDB416",
    'Mesenchymal': '#BBBBBB',
    'pDC': "#69FFCB",
    'CD4 CTL': "#D7D2CB",
    'MEP': "#E364B0",
    'Pre-B': "#2DBD67",
    'Pre-Pro-B': '#92AC8E',
    'EoBaMaP': "#728245",
    'MkP': "#69424D",
    'Stroma': "#727272",
    'Macrophage': "#5F4761",
    'ILC': "#F7CF94",
    'DnT': "#504423",
    'Treg': "#6E6C37",
    'Platelet': "#FF39A6",
}

# --- Ensure categorical dtype ---
# isinstance(dtype, pd.CategoricalDtype) replaces pd.api.types.is_categorical_dtype,
# which is deprecated since pandas 2.1.
if not isinstance(Luecken_dataset_Train.obs[label_key].dtype, pd.CategoricalDtype):
    Luecken_dataset_Train.obs[label_key] = Luecken_dataset_Train.obs[label_key].astype('category')

cats = list(Luecken_dataset_Train.obs[label_key].cat.categories)

# --- Sanity checks: missing/extra labels ---
labels_in_palette = set(custom_palette.keys())
labels_in_data = set(cats)

missing_in_palette = [c for c in cats if c not in labels_in_palette]
extra_in_palette   = [c for c in custom_palette.keys() if c not in labels_in_data]

if missing_in_palette:
    print("[WARN] Missing colors for:", missing_in_palette, "-> will use light grey (#cccccc).")
if extra_in_palette:
    print("[INFO] Palette has unused entries:", extra_in_palette)

# --- Build palette list in *category order* ---
fallback = '#cccccc'
palette_list = [custom_palette.get(c, fallback) for c in cats]

# --- Save onto .uns so Scanpy uses it consistently elsewhere ---
Luecken_dataset_Train.uns[color_key] = palette_list

# --- Plot with Scanpy using built-in outlines (no clustering) ---
with rc_context({"figure.figsize": (5, 4.5)}):
    sc.pl.embedding(
        Luecken_dataset_Train,
        basis=basis_key,
        color=label_key,
        palette=palette_list,
        legend_loc='on data',
        legend_fontsize=10,
        legend_fontoutline=1.5,
        size=10,
        add_outline=True,   # built-in group outlines
        frameon=False,
        title='Luecken',
        show=False,
    )
    ax = plt.gca()
    ax.set_xlabel('')
    ax.set_ylabel('')
    plt.tight_layout()
    plt.savefig(figures_path + "/Luecken_final_annotation.png",
        dpi=300, bbox_inches='tight')
    plt.show()


In [None]:
# Normalise protein data in place for each Luecken split (espressopro, "seurat" flavour).
ep.Normalise_protein_data(Luecken_dataset_Train, inplace=True, axis=1, flavor="seurat")
ep.Normalise_protein_data(Luecken_dataset_Cal,   inplace=True, axis=1, flavor="seurat")
ep.Normalise_protein_data(Luecken_dataset_Test,  inplace=True, axis=1, flavor="seurat")

In [None]:
# Wilcoxon ranking of features per detailed consensus label (TRAIN only) and top-10 plot.
sc.tl.rank_genes_groups(Luecken_dataset_Train, 'Consensus_annotation_detailed_final', method='wilcoxon')
sc.pl.rank_genes_groups(Luecken_dataset_Train, n_genes=10, sharey=False, ncols = 3, fontsize = 14)

In [None]:
# CD49b expression (from .X) grouped by the detailed consensus annotation.
sc.pl.violin(Luecken_dataset_Train, keys='CD49b', groupby='Consensus_annotation_detailed_final', rotation=90, use_raw=False)

### ML pre-processing

In [None]:
# Wrap each split's .X in a sparse-backed DataFrame (rows = obs_names, columns = var_names).
# NOTE(review): from_spmatrix requires .X to be scipy-sparse — confirm after normalisation.
Luecken_data_Train = pd.DataFrame.sparse.from_spmatrix(Luecken_dataset_Train.X, index=Luecken_dataset_Train.obs_names, columns=Luecken_dataset_Train.var_names)
Luecken_data_Test = pd.DataFrame.sparse.from_spmatrix(Luecken_dataset_Test.X, index=Luecken_dataset_Test.obs_names, columns=Luecken_dataset_Test.var_names)
Luecken_data_Cal = pd.DataFrame.sparse.from_spmatrix(Luecken_dataset_Cal.X, index=Luecken_dataset_Cal.obs_names, columns=Luecken_dataset_Cal.var_names)

In [None]:
# Protein features excluded from the model inputs (absent columns are silently skipped).
columns_to_drop = ["IgG", "IgM", "IgD"]
Luecken_data_Train, Luecken_data_Test, Luecken_data_Cal = (
    frame.drop(columns=columns_to_drop, errors='ignore')
    for frame in (Luecken_data_Train, Luecken_data_Test, Luecken_data_Cal)
)

In [None]:
# Cells per detailed consensus label in the Luecken training split.
Luecken_dataset_Train.obs['Consensus_annotation_detailed_final'].value_counts()

# Number of cells per partition

In [None]:
# ------- CONFIG -------
# Settings shared by the per-dataset "cells per partition" bar charts below.
label_col = "Consensus_annotation_detailed_final"

partition_colors = {
    "Train": "#023047",  # dark blue
    "Test":  "#ffb703",  # amber
    "Cal":   "#edede9",  # light gray
}

# Label -> dot colour. Includes 'Treg', 'Platelet' and 'Mesenchymal' with the same
# colours as the per-dataset embedding palettes, so those rows don't fall back to grey.
custom_palette = {
    'B Memory': "#68D827", 'B Naive': '#1C511D', 'CD14 Mono': "#D27CE3",
    'CD16 Mono': "#8D43CD", 'CD4 T Memory': "#C1AF93", 'CD4 T Naive': "#C99546",
    'CD8 T Memory': "#6B3317", 'CD8 T Naive': "#4D382E", 'ErP': "#D1235A",
    'Erythroblast': "#F30A1A", 'GMP': "#C5E4FF", 'HSC': '#0079ea', 'MPP': "#79b6ac",
    'Immature B': "#91FF7B", 'LMPP': "#17BECF", 'MAIT': "#BCBD22", 'HSC_MPP': '#0079ea',
    'Myeloid progenitor': "#AEC7E8", 'NK CD56 bright': "#F3AC1F",
    'NK CD56 dim': "#FBEF0D", 'Plasma': "#9DC012", 'Pro-B': "#66BB6A",
    'Small': "#292929", 'cDC1': "#76A7CB", 'cDC2': "#16D2E3", 'GdT': "#EDB416",
    'pDC': "#69FFCB", 'CD4 CTL': "#D7D2CB", 'MEP': "#E364B0", 'Pre-B': "#2DBD67",
    'Pre-Pro-B': '#92AC8E', 'EoBaMaP': "#728245", 'MkP': "#69424D",
    'Stroma': "#727272", 'Macrophage': "#5F4761", 'ILC': "#F7CF94", 'DnT': "#504423",
    'GdT_DnT': "#B07A2A",
    'Treg': "#6E6C37", 'Platelet': "#FF39A6", 'Mesenchymal': '#BBBBBB',
}

# Ensure figures_path exists (falls back to the current directory if it was never defined)
try:
    os.makedirs(figures_path, exist_ok=True)
except NameError:
    figures_path = "."
    os.makedirs(figures_path, exist_ok=True)

# ---- FIXED OUTPUT SIZE (Option B) ----
OUT_W_IN, OUT_H_IN = 3.6, 7.2   # inches -> identical pixel size across datasets
OUT_DPI = 300                   # export dpi
# fixed axes rectangle [left, bottom, width, height] in figure fraction
# Increase 'left' if y labels are long; decrease if you want more plot width.
AX_RECT = [0.42, 0.06, 0.56, 0.90]

def _obs_to_df(ad, part_name):
    """Return a tidy per-cell frame of ``label_col`` labels for one partition.

    Columns are ``"Partition"`` (constant *part_name*) and ``label_col``
    (labels as ``str``). Missing labels become ``"Unknown"``.
    """
    # Fill NaN *before* converting to str: astype(str) turns NaN into the literal
    # string "nan", so a later fillna would never fire. Going through object first
    # also avoids categorical fillna rejecting "Unknown" as a new category.
    s = ad.obs[label_col].astype(object).fillna("Unknown").astype(str)
    return pd.DataFrame({"Partition": part_name, label_col: s})

def plot_partition_counts_hstack_log10_with_dots(ad_train, ad_test, ad_cal, dataset_name: str):
    """Save a per-cell-type horizontal bar chart of cell counts by partition.

    For each label in ``label_col``, the bar length is log10 of the total cells
    across Train+Test+Cal. The bar is split into Train/Test/Cal segments in
    proportion to each partition's share of that total, so segment widths
    show composition, not individual log-counts.

    Rows are ordered by Train count (desc), then total count (desc), then
    alphabetically. A coloured dot from ``custom_palette`` marks each row.

    The figure has a fixed size and axes rectangle (``OUT_W_IN``, ``OUT_H_IN``,
    ``AX_RECT``) so PNGs have identical pixel dimensions across datasets. It is
    written to ``figures_path/N_celltypes_<dataset_name>.png``.

    Raises
    ------
    KeyError
        If any of the three AnnData objects lacks ``label_col`` in ``.obs``.
    """
    # sanity
    for ad in (ad_train, ad_test, ad_cal):
        if label_col not in ad.obs:
            raise KeyError(f"'{label_col}' missing in one of the AnnData objects for {dataset_name}")

    # tidy df: one row per cell with its partition and label
    df = pd.concat([
        _obs_to_df(ad_train, "Train"),
        _obs_to_df(ad_test,  "Test"),
        _obs_to_df(ad_cal,   "Cal"),
    ], ignore_index=True)

    # counts: rows = cell type (alphabetical, from groupby), cols = partition
    count_table = (
        df.groupby([label_col, "Partition"])
          .size()
          .unstack(fill_value=0)
    )

    # Order rows by Train count desc, then total desc, in one stable sort. Rows
    # tied on both keys keep groupby's alphabetical order, so the row order is
    # deterministic.
    train_counts = count_table.get("Train", pd.Series(0, index=count_table.index))
    total_counts = count_table.sum(axis=1)
    order = (
        pd.DataFrame({"train": train_counts, "total": total_counts}, index=count_table.index)
          .sort_values(["train", "total"], ascending=[False, False], kind="mergesort")
          .index
    )
    count_table = count_table.loc[order]

    # --- log10 stacking: full bar = log10(total); empty rows stay at 0 ---
    totals = count_table.sum(axis=1).to_numpy(dtype=float)
    log_totals = np.zeros_like(totals, dtype=float)
    mask = totals > 0
    log_totals[mask] = np.log10(totals[mask])

    labels = count_table.index.to_list()

    # ---- FIXED GEOMETRY FIGURE (Option B) ----
    fig = plt.figure(figsize=(OUT_W_IN, OUT_H_IN), dpi=OUT_DPI)
    ax = fig.add_axes(AX_RECT)  # fixed axes area -> consistent geometry across datasets

    left = np.zeros_like(log_totals)

    for part in ["Train", "Test", "Cal"]:
        counts = count_table[part].to_numpy(dtype=float) if part in count_table.columns else np.zeros_like(totals)
        frac = np.divide(counts, totals, out=np.zeros_like(counts), where=mask)
        seg = log_totals * frac  # segment width proportional to composition
        ax.barh(
            labels,
            seg,
            left=left,
            color=partition_colors[part],
            label=part,
            edgecolor="black",
            linewidth=0.8,
        )
        left += seg

    # style
    ax.set_xlabel("log₁₀(cells)")
    ax.set_ylabel("")
    ax.spines[['top', 'right']].set_visible(False)
    ax.tick_params(axis='y', labelsize=8)
    ax.invert_yaxis()  # first row (largest Train count) at the top

    # x-axis ticks and vertical guides
    tick_vals = [1, 2, 3, 4, 5]
    xmax_data = float(np.nanmax(log_totals)) if len(log_totals) else 0.0
    xmax = max(xmax_data, max(tick_vals))
    ax.set_xlim(0, xmax * 1.05 if xmax > 0 else 1)

    ax.set_xticks(tick_vals)
    ax.set_xticklabels([str(t) for t in tick_vals])

    for xv in tick_vals:
        ax.axvline(x=xv, linestyle="--", linewidth=0.6, color="#BDBDBD", alpha=0.8, zorder=0)

    # Cell-type colour dots: x fixed at 2% of the axes width from its left edge
    # (axes coords), y at each row (data coords); drawn on top of the bar starts.
    dot_x_axes = 0.02
    y_transform = ax.get_yaxis_transform()  # x in axes coords, y in data coords
    for y, label in enumerate(labels):
        color = custom_palette.get(label, "#BBBBBB")
        ax.scatter(
            dot_x_axes, y,
            s=22,
            color=color,
            edgecolor="black",
            linewidth=0.3,
            zorder=5,
            clip_on=False,
            transform=y_transform
        )

    # legend bottom-right
    ax.legend(
        loc="lower right",
        frameon=False,
        fontsize=7,
        ncol=1,
        handlelength=1.2,
        handleheight=0.8,
        labelspacing=0.3,
        borderaxespad=0.5
    )

    # IMPORTANT: do NOT use tight_layout or bbox_inches='tight' if you want identical output size
    out_png = os.path.join(figures_path, f"N_celltypes_{dataset_name}.png")
    fig.savefig(out_png, dpi=OUT_DPI)  # fixed pixel size
    plt.close(fig)
    print(f"[INFO] Saved: {out_png}")

# -------- RUN FOR THE FOUR DATASETS (already loaded) --------
# One PNG per dataset: (Train, Test, Cal) splits of the same reference.
plot_partition_counts_hstack_log10_with_dots(
    Luecken_dataset_Train, Luecken_dataset_Test, Luecken_dataset_Cal, dataset_name="Luecken"
)
plot_partition_counts_hstack_log10_with_dots(
    Triana_dataset_Train, Triana_dataset_Test, Triana_dataset_Cal, dataset_name="Triana"
)
plot_partition_counts_hstack_log10_with_dots(
    Hao_dataset_Train, Hao_dataset_Test, Hao_dataset_Cal, dataset_name="Hao"
)
plot_partition_counts_hstack_log10_with_dots(
    Zhang_dataset_Train, Zhang_dataset_Test, Zhang_dataset_Cal, dataset_name="Zhang"
)


# Loading antibody panels

In [None]:
# Antibody panel definition: only the CSV's first column (used as the index) is kept, as a pandas Index.
TotalSeqD_Heme_Oncology_CAT399906 = pd.read_csv(data_path + "/Antibodies_panels/TotalSeqD_Heme_Oncology_CAT399906.csv", index_col=0).index

# TotalSeqD Heme Oncology CAT399906 Models Training

In [None]:
# Root folder for this antibody panel's pre-trained models. Rebuilt from parent_dir
# (same value as the original parent_dir + "/Data" base) rather than appended to
# data_path, so re-running this cell does not nest the suffix a second time.
data_path = parent_dir + "/Data" + '/Pre_trained_models/TotalSeqD_Heme_Oncology_CAT399906'
os.makedirs(data_path, exist_ok=True)

## Hao Models

In [None]:
models_output = data_path + "/Hao"

# Build <models_output>/{Dev,Release}/Models; makedirs also creates the
# intermediate Hao/, Hao/Dev and Hao/Release folders.
for _stage in ("Dev", "Release"):
    os.makedirs(f"{models_output}/{_stage}/Models", exist_ok=True)
del _stage

### ML Training

In [None]:
# Container for the Hao models trained in the following cells (starts empty).
Hao_Models = {}

#### Broad annotation

In [None]:
# -*- coding: utf-8 -*-
# =============================================================================
# MODEL TRAINING PIPELINE (LEAN MAIN SCRIPT)
#   - RAW vs PLATT vs TEMP-SCALED
#   - DEV/RELEASE exports
#   - Importances: XGB SHAP mean_abs + corr (Top10) + LR meta-learner contributions
#   - Platt calibration plots (Ideal -> RAW -> Platt on top) with TEST LogLoss/Brier in legend
#   - Per-class pre/post Platt metrics exported to CSV
#   - Per-class TRAIN UMAP (pos vs rest) + legend PNG
#
# NOTE:
#   Many helper functions are now provided by MLTraining.py.
#   This script should primarily orchestrate: data prep -> model training loop -> exports.
# =============================================================================

from pathlib import Path
import joblib
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import StackingClassifier
from sklearn.calibration import CalibratedClassifierCV
from sklearn.metrics import (
    classification_report, confusion_matrix, ConfusionMatrixDisplay,
    roc_curve, auc,
)

import MLTraining  # uses MLTraining.py helpers

# -----------------------------------------------------------------------------
# Palettes
# -----------------------------------------------------------------------------

# Class-label -> hex colour, one palette per annotation depth. Keys are expected
# to match the "Celltype" labels of the corresponding depth.

# Broad depth: immature vs. mature compartments.
PALETTE_BROAD = {
    'Immature': "#0079ea", 
    'Mature': "#AF3434"
}

# Simplified depth: major lineages.
PALETTE_SIMPLIFIED = {
    "HSPC":      "#0079ea",
    "Erythroid": "#c11212",
    "pDC":       "#62E6B8",
    "Monocyte":  "#D27CE3",
    "Myeloid":   "#8D43CD",
    "CD4_T":     "#C99546",
    "CD8_T":     "#6B3317",
    "B":         "#68D827",
    "cDC":       "#16D2E3",
    "Other_T":   "#EDB416",
    "NK":        "#FBEF0D",
}

# Detailed depth: fine-grained populations. Note 'Monocyte' and 'CD14 Mono'
# deliberately share the same colour ("#D27CE3").
PALETTE_DETAILED = {
    'HSC_MPP':            '#0079ea',
    'LMPP':               "#17BECF",
    'GMP':                "#C5E4FF",
    'Myeloid progenitor': "#AEC7E8",
    'Monocyte':           "#D27CE3",
    'CD14 Mono':         "#D27CE3",
    'CD16 Mono':         "#8D43CD",
    'Erythroblast':      "#F30A1A",
    'ErP':               "#D1235A",
    'MEP':               "#E364B0",
    'CD4 T Naive':       "#C99546",
    'CD4 T Memory':      "#C1AF93",
    'CD8 T Naive':       "#4D382E",
    'CD8 T Memory':      "#6B3317",
    'Other_T':           "#EDB416",
    'Treg':              "#6E6C37",
    'B Naive':          '#1C511D',
    'B Memory':         "#68D827",
    'Pro-B':            "#66BB6A",
    'Pre-B':            "#2DBD67",
    'Immature B':      "#91FF7B",
    'Plasma':           "#9DC012",
    'cDC1':             "#76A7CB",
    'cDC2':             "#16D2E3",
    'pDC':              "#69FFCB",
    'NK CD56 bright':  "#F3AC1F",
    'NK CD56 dim':     "#FBEF0D",
}

# -----------------------------------------------------------------------------
# OPTIONAL: SHAP dependency
# -----------------------------------------------------------------------------
# SHAP is optional: STEP 4.11 only computes SHAP importances when the import works.
# Any import-time failure (not just ImportError) disables it.
try:
    import shap  # noqa: F401
except Exception:
    HAS_SHAP = False
else:
    HAS_SHAP = True


# -----------------------------------------------------------------------------
# EXPORT SWITCHES
# -----------------------------------------------------------------------------
# Every file below is written only for the enabled target(s); the matching
# directory trees are created conditionally in the ROOTS section.
EXPORT_RELEASE = True    # set False to disable Release outputs
EXPORT_DEV     = False   # set True to enable Dev outputs


# -----------------------------------------------------------------------------
# CONFIG
# -----------------------------------------------------------------------------
name_target_class = "Broad"  # "Broad" | "Simplified" | "Detailed"
# NOTE: the next two lines rebind the notebook-level `kf` (KFold(5) from the set-up
# cell) and `num_cores` (cpu_count() - 2). CV now comes from MLTraining.CV, and -1 is
# sklearn's "use all cores" value (presumably also how the MLTraining trainers use it).
kf          = MLTraining.CV
num_cores   = -1
metrics_log = []  # per-head TEST metric dicts appended in STEP 4.9

# -----------------------------------------------------------------------------
# EMBEDDING CONFIG (for Class_Train_data.png)
# -----------------------------------------------------------------------------
# Choose where to read the 2D embedding from.
# Supported:
#   - "adata_obsm": read from adata_train.obsm[obsm_key]
#   - "adata_obs":  read from adata_train.obs[[obs_x, obs_y]]
#   - "train_df":   read from train_df[[df_x, df_y]] (e.g., Hao_data_Train has UMAP columns)
# All settings below are forwarded to MLTraining.save_class_train_umap_pngs in
# STEP 4.1b; presumably only those matching EMBEDDING_SOURCE are actually read.
EMBEDDING_SOURCE = "adata_obsm"   # "adata_obsm" | "adata_obs" | "train_df"

# If EMBEDDING_SOURCE == "adata_obsm"
EMBEDDING_OBSM_KEY = "X_wnn.umap"     # e.g. "X_umap", "X_pca"

# If EMBEDDING_SOURCE == "adata_obs"
EMBEDDING_OBS_X = "UMAP_1"
EMBEDDING_OBS_Y = "UMAP_2"

# If EMBEDDING_SOURCE == "train_df"
EMBEDDING_DF_X = "UMAP_1"
EMBEDDING_DF_Y = "UMAP_2"


# -----------------------------------------------------------------------------
# ROOTS
# -----------------------------------------------------------------------------
# Base folder for Hao outputs (models_output is set in the "Hao Models" cell).
hao_root = Path(models_output)

# DEV tree (used only when EXPORT_DEV).
# NOTE(review): name_target_class appears twice in these paths
# (e.g. Dev/Broad/Models/Broad), unlike the RELEASE tree (Release/Broad/Models).
# Confirm the extra nesting is intended before relying on these locations.
dev_root     = hao_root / "Dev"
models_root  = dev_root / name_target_class / "Models"  / name_target_class
reports_root = dev_root / name_target_class / "Reports" / name_target_class
fig_root     = dev_root / name_target_class / "Figures" / name_target_class

heads_dir    = models_root / "heads"
metrics_dir  = reports_root / "metrics"
probs_dir    = reports_root / "probabilities"
fig_percls   = fig_root / "per_class"
dev_importances = reports_root / "Importances"

# RELEASE tree (used only when EXPORT_RELEASE).
release_root     = hao_root / "Release"
release_models   = release_root / name_target_class / "Models"
release_reports  = release_root / name_target_class / "Reports"
release_metrics  = release_reports / "Metrics"
release_probs    = release_reports / "Probabilities"
release_imps     = release_reports / "Importances"
release_figs     = release_root / name_target_class / "Figures"
release_single   = release_figs / "Single_classes"

# Create directories conditionally (only the trees of enabled export targets)
if EXPORT_DEV:
    for p in (models_root, heads_dir, reports_root, metrics_dir, probs_dir, fig_root, fig_percls, dev_importances):
        p.mkdir(parents=True, exist_ok=True)
    print(f"[INFO] DEV Models:  {models_root}")
    print(f"[INFO] DEV Reports: {reports_root}")
    print(f"[INFO] DEV Figures: {fig_root}")

if EXPORT_RELEASE:
    for p in (release_models, release_reports, release_metrics, release_probs, release_imps, release_figs, release_single):
        p.mkdir(parents=True, exist_ok=True)
    print(f"[INFO] RELEASE Root:    {release_root}")
    print(f"[INFO] RELEASE Models:  {release_models}")
    print(f"[INFO] RELEASE Reports: {release_reports}")
    print(f"[INFO] RELEASE Figures: {release_figs}")


# =============================================================================
# SECTION 1: ATTACH CELL-TYPE LABELS
# =============================================================================
print("\n[STEP 1] Attaching cell-type labels from AnnData.obs...")

# obs column holding the consensus label for the configured depth,
# e.g. "Consensus_annotation_broad_final" when name_target_class == "Broad".
consensus_field = f"Consensus_annotation_{name_target_class.lower()}_final"

# MLTraining.attach_celltype presumably copies `consensus_field` from each AnnData
# split onto the matching feature DataFrame as the "Celltype" column read in STEP 3
# — confirm in MLTraining.py.
Hao_data_Train = MLTraining.attach_celltype(Hao_data_Train, Hao_dataset_Train, consensus_field)
Hao_data_Test  = MLTraining.attach_celltype(Hao_data_Test,  Hao_dataset_Test,  consensus_field)
Hao_data_Cal   = MLTraining.attach_celltype(Hao_data_Cal,   Hao_dataset_Cal,   consensus_field)

print(f"  ✓ Attached '{consensus_field}' to Train/Test/Cal splits")


# =============================================================================
# SECTION 2: ALIGN DATA COLUMNS TO REFERENCE PANEL
# =============================================================================
print("\n[STEP 2] Aligning data columns to reference panel (exact names preserved)...")

# Panel antibody names as strings; panel order is the canonical feature order.
panel = pd.Index(map(str, TotalSeqD_Heme_Oncology_CAT399906))
# Normalized key -> exact panel name, used to match differently spelled data columns.
panel_keys = MLTraining.norm_feats(panel)
norm_to_panel = dict(zip(panel_keys, panel))
# Fewer dict entries than panel names means two names share a normalized key (or
# the panel has duplicates), which would make the column mapping ambiguous.
if len(norm_to_panel) != len(panel):
    raise ValueError("Panel contains names that collide after normalization. Adjust MLTraining.norm_feats rules.")

def rename_data_to_panel(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with feature columns renamed to their exact panel names.

    Every column except ``cell_barcode`` / ``Celltype`` is normalised with
    ``MLTraining.norm_feats`` and looked up in the module-level ``norm_to_panel``.
    Columns without a panel match keep their name. When several columns resolve to
    the same panel name, the first one (in column order) is renamed and the others
    are dropped and reported. Progress lines are printed to stdout.
    """
    meta_cols = [name for name in ("cell_barcode", "Celltype") if name in df.columns]
    feature_cols = pd.Index([name for name in df.columns if name not in meta_cols])

    targets = [norm_to_panel.get(key) for key in MLTraining.norm_feats(feature_cols)]
    candidates = {
        column: target
        for column, target in zip(feature_cols, targets)
        if target is not None
    }

    # First claimant of a panel name wins; later ones are scheduled for removal.
    accepted: dict = {}
    rejected: list = []
    claimed: set = set()
    for column, target in candidates.items():
        if target not in claimed:
            claimed.add(target)
            accepted[column] = target
        else:
            rejected.append(column)

    result = df.copy()
    if rejected:
        print(f"  [WARN] Dropping {len(rejected)} duplicated-mapped columns (sample: {rejected[:5]})")
        result = result.drop(columns=rejected, errors="ignore")

    result = result.rename(columns=accepted)
    print(f"  ✓ Matched {len(accepted)}/{len(feature_cols)} data columns to panel")
    return result

def panel_intersection(df: pd.DataFrame) -> pd.DataFrame:
    """Keep only panel features (in module-level ``panel`` order) plus metadata columns.

    The returned frame has the feature columns shared with ``panel`` first, followed
    by whichever of ``cell_barcode`` / ``Celltype`` are present.

    Raises:
        ValueError: if ``df`` shares no feature column with ``panel``.
    """
    meta_cols = [name for name in ("cell_barcode", "Celltype") if name in df.columns]
    is_feature = ~df.columns.isin(meta_cols)
    shared = panel.intersection(df.columns[is_feature], sort=False)
    if len(shared) == 0:
        raise ValueError("Panel/Data intersection is empty after renaming. Check mapping rules.")
    return df.reindex(columns=[*shared, *meta_cols])

# Rename each split's columns to exact panel names, then keep only panel features
# (panel order) followed by the cell_barcode / Celltype metadata columns.
Hao_data_Train = panel_intersection(rename_data_to_panel(Hao_data_Train))
Hao_data_Test  = panel_intersection(rename_data_to_panel(Hao_data_Test))
Hao_data_Cal   = panel_intersection(rename_data_to_panel(Hao_data_Cal))

print("  ✓ Data columns now aligned to panel (panel order preserved)")


# =============================================================================
# SECTION 3: PREPARE FEATURES & LABELS
# =============================================================================
print("\n[STEP 3] Extracting features and labels...")

# CAL labels are kept separately; the CAL feature matrix below drops "Celltype".
Hao_data_Cal_lbl = Hao_data_Cal[["Celltype"]].copy()

# Feature-only matrices: drop whichever metadata columns are present.
drop_cols_train = [c for c in ["cell_barcode", "Celltype"] if c in Hao_data_Train.columns]
drop_cols_test  = [c for c in ["cell_barcode", "Celltype"] if c in Hao_data_Test.columns]
drop_cols_cal   = [c for c in ["cell_barcode", "Celltype"] if c in Hao_data_Cal.columns]

Hao_data_Train_Sub = Hao_data_Train.drop(columns=drop_cols_train, errors="ignore")
Hao_data_Test_Sub  = Hao_data_Test.drop(columns=drop_cols_test,  errors="ignore")
Hao_data_Cal_Sub   = Hao_data_Cal.drop(columns=drop_cols_cal,    errors="ignore")

# All splits must have identical feature columns in identical order: each head's
# scaler and models are fitted on TRAIN columns and then applied to TEST/CAL.
cols_train = list(Hao_data_Train_Sub.columns)
if list(Hao_data_Test_Sub.columns) != cols_train or list(Hao_data_Cal_Sub.columns) != cols_train:
    raise ValueError("Train/Cal/Test feature columns differ after panel intersection!")

# Presumably raises on NaN/Inf feature values — confirm in MLTraining.check_finite.
MLTraining.check_finite(Hao_data_Train_Sub, "TRAIN")
MLTraining.check_finite(Hao_data_Test_Sub,  "TEST")
MLTraining.check_finite(Hao_data_Cal_Sub,   "CAL")

print(f"  ✓ Using {len(cols_train)} panel-intersected features (exact panel names)")
print(f"    Sample: {cols_train[:5]}...")

# Class vocabulary from TRAIN labels (sorted, NaN excluded). A class's index k is
# its column in every probability matrix below.
class_names  = sorted(pd.Series(Hao_data_Train["Celltype"]).dropna().unique())
K            = len(class_names)
class_to_idx = {c: i for i, c in enumerate(class_names)}
print(f"  ✓ Found {K} classes")

# Integer-encode CAL labels; any label not seen in TRAIN (NaN included) is an error.
s_cal = Hao_data_Cal_lbl["Celltype"].map(class_to_idx)
if s_cal.isna().any():
    missing = Hao_data_Cal_lbl.loc[s_cal.isna(), "Celltype"].unique()
    raise ValueError(f"Unknown labels in CAL split: {missing}")
y_cal_multiclass = s_cal.to_numpy(dtype=np.int64)

# Same encoding and check for TEST.
s_te = Hao_data_Test["Celltype"].map(class_to_idx)
if s_te.isna().any():
    missing = Hao_data_Test.loc[s_te.isna(), "Celltype"].unique()
    raise ValueError(f"Unknown labels in TEST split: {missing}")
y_test_multiclass = s_te.to_numpy(dtype=np.int64)

# Full CAL / TEST feature matrices, shared by every OvR head.
X_cal_all_df = Hao_data_Cal_Sub.copy()
X_te_all_df  = Hao_data_Test_Sub.copy()
test_index   = Hao_data_Test_Sub.index

# Column k (= class_to_idx[class]) holds that head's positive-class probability.
# Classes skipped in STEP 4 (empty/single-class TRAIN) keep all-zero columns.
P_cal_raw   = np.zeros((X_cal_all_df.shape[0], K), dtype=float)
P_cal_platt = np.zeros((X_cal_all_df.shape[0], K), dtype=float)

P_te_raw    = np.zeros((X_te_all_df.shape[0],  K), dtype=float)
P_te_platt  = np.zeros((X_te_all_df.shape[0],  K), dtype=float)

heads_mem = {}  # class label -> head bundle; packaged in STEP 5b

# Importances collectors
xgb_shap_rows = []       # mean_abs + corr (later filtered top10/class)
lr_contrib_rows = []     # LR base learner contributions (from stacker_raw)
platt_metrics_rows = []  # per-class logloss/brier pre vs post platt


# =============================================================================
# SECTION 4: TRAIN OvR BINARY HEADS (+ Platt on CAL)
# =============================================================================
print(f"\n[STEP 4] Training {K} binary OvR classifiers...\n")

TOP_N = 10
base_order = ["NB", "XGB", "KNN", "MLP"]

# Palette for the per-class TRAIN UMAPs, chosen from the configured depth so that
# switching name_target_class no longer keeps drawing with the Broad palette.
# Unknown depth names fall back to PALETTE_BROAD (the previous hard-coded choice).
PALETTE_BY_DEPTH = {
    "Broad":      PALETTE_BROAD,
    "Simplified": PALETTE_SIMPLIFIED,
    "Detailed":   PALETTE_DETAILED,
}
class_palette = PALETTE_BY_DEPTH.get(name_target_class, PALETTE_BROAD)

# Classes for which MLTraining.save_class_train_umap_pngs is called with debug=True.
UMAP_DEBUG_CLASSES = {"Mature"}


def _report_failure(message: str) -> None:
    """Print a notice for a best-effort step that failed.

    The notebook sets ``warnings.filterwarnings('ignore')``, so ``warnings.warn``
    would discard these messages; printing keeps the failures visible.
    """
    print(f"    [WARN] {message}")


for celltype in class_names:
    k = class_to_idx[celltype]
    cls_safe = MLTraining.safe_name(celltype)
    print(f"▸ Processing {cls_safe} (class {k+1}/{K})")

    # 4.1 Load TRAIN barcodes for this class ("Positive" / "Negative" columns)
    train_barcodes_df = pd.read_csv(
        f"{train_barcodes_path}/Hao/Consensus_annotation_{name_target_class.lower()}_final/Barcodes_training_class_{cls_safe}.csv",
        index_col=0
    )
    train_positive_barcodes = train_barcodes_df["Positive"].dropna().values
    train_negative_barcodes = train_barcodes_df["Negative"].dropna().values
    all_train_barcodes = np.concatenate([train_positive_barcodes, train_negative_barcodes])

    # Keep the listed barcodes present in the TRAIN matrix; label 1 = positive barcode
    train_mask = Hao_data_Train_Sub.index.isin(all_train_barcodes)
    X_tr_df = Hao_data_Train_Sub.loc[train_mask]
    found_train_barcodes = X_tr_df.index.values
    y_tr = np.isin(found_train_barcodes, train_positive_barcodes).astype(int)

    # A skipped class gets no head and keeps all-zero columns in the P_* matrices.
    if X_tr_df.empty or np.unique(y_tr).size < 2:
        print(f"  [SKIP] Empty or single-class train (pos={y_tr.sum()}, neg={len(y_tr)-y_tr.sum()})\n")
        continue

    # 4.1b TRAIN UMAP (pos vs rest) + legend (best effort)
    try:
        MLTraining.save_class_train_umap_pngs(
            celltype=str(celltype),
            cls_safe=cls_safe,
            barcodes=found_train_barcodes,
            y_bin=y_tr,
            custom_palette=class_palette,
            out_dir_dev=fig_percls if EXPORT_DEV else None,
            out_dir_rel=release_single if EXPORT_RELEASE else None,
            adata_train=Hao_dataset_Train,
            train_df=Hao_data_Train,
            embedding_source=EMBEDDING_SOURCE,
            obsm_key=EMBEDDING_OBSM_KEY,
            obs_x=EMBEDDING_OBS_X,
            obs_y=EMBEDDING_OBS_Y,
            df_x=EMBEDDING_DF_X,
            df_y=EMBEDDING_DF_Y,
            neg_color="#A3A3A3",
            outline=(5, 0.05),
            debug=(str(celltype) in UMAP_DEBUG_CLASSES),
        )
    except Exception as e:
        _report_failure(f"UMAP train plot failed for '{celltype}': {e}")

    # 4.2 Load TEST barcodes; used only for the class-specific metrics in 4.9
    test_barcodes_df = pd.read_csv(
        f"{test_barcodes_path}/Hao/Consensus_annotation_{name_target_class.lower()}_final/Barcodes_testing_class_{cls_safe}.csv",
        index_col=0
    )
    test_positive_barcodes = test_barcodes_df["Positive"].dropna().values
    test_negative_barcodes = test_barcodes_df["Negative"].dropna().values
    all_test_barcodes = np.concatenate([test_positive_barcodes, test_negative_barcodes])

    test_mask = Hao_data_Test_Sub.index.isin(all_test_barcodes)
    X_te_df = Hao_data_Test_Sub.loc[test_mask]
    found_test_barcodes = X_te_df.index.values
    y_te = np.isin(found_test_barcodes, test_positive_barcodes).astype(int)

    # Full TEST for head probabilities / calibration plot eval
    X_te_all_local = X_te_all_df
    y_te_all = (Hao_data_Test["Celltype"].values == celltype).astype(int)

    # CAL split for Platt fitting (this class vs. all others)
    X_cal_df  = X_cal_all_df
    y_cal_bin = (Hao_data_Cal_lbl["Celltype"].values == celltype).astype(int)

    # 4.3 Fit scaler on TRAIN only; transform all splits with it
    scaler = StandardScaler(with_mean=True, with_std=True).fit(X_tr_df.values)

    def _sc(df: pd.DataFrame) -> pd.DataFrame:
        """Scale ``df`` with this class's TRAIN scaler, keeping index and panel column names."""
        return pd.DataFrame(
            scaler.transform(df.values),
            index=df.index,
            columns=cols_train
        )

    X_tr_sc_df      = _sc(X_tr_df)
    # StandardScaler.transform rejects 0-row input, so an empty class-specific TEST
    # subset is kept as None and 4.9 skips its metrics instead of aborting the run.
    X_te_sc_df      = _sc(X_te_df) if not X_te_df.empty else None
    X_te_all_sc_df  = _sc(X_te_all_local)
    X_cal_sc_df     = _sc(X_cal_df)

    # 4.4 Train base learners
    NB_model  = MLTraining.train_NB(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)
    XGB_model = MLTraining.train_XGB(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)
    KNN_model = MLTraining.train_KNN(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)
    MLP_model = MLTraining.train_MLP(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)

    # 4.5 Stacking RAW head: balanced LR meta-learner over the base learners' predict_proba
    stacker_raw = StackingClassifier(
        estimators=[("NB", NB_model), ("XGB", XGB_model), ("KNN", KNN_model), ("MLP", MLP_model)],
        final_estimator=LogisticRegression(max_iter=2000, class_weight="balanced", random_state=42),
        stack_method="predict_proba",
        cv=kf,
        n_jobs=num_cores,
    ).fit(X_tr_sc_df, y_tr)

    # 4.6 Platt calibration (fit on CAL only); needs both classes present in CAL
    pos_cal   = int(y_cal_bin.sum())
    n_cal_bin = int(len(y_cal_bin))
    has_both  = (0 < pos_cal < n_cal_bin)

    stacker_platt = None
    if has_both:
        stacker_platt = MLTraining.calibrate_prefit(stacker_raw, X_cal_sc_df, y_cal_bin, method="sigmoid")
    else:
        print("    [WARN] Skipped Platt calibration (single-class CAL)")

    # 4.7 Full-TEST probabilities, computed once and reused for the calibration plot
    # and the OvR matrices (previously predicted twice). The matrix columns are filled
    # before plotting; without a Platt head the PLATT column mirrors RAW.
    p_test_raw   = stacker_raw.predict_proba(X_te_all_sc_df)[:, 1]
    p_test_platt = stacker_platt.predict_proba(X_te_all_sc_df)[:, 1] if stacker_platt is not None else None
    P_te_raw[:,  k] = p_test_raw
    P_te_platt[:, k] = p_test_platt if p_test_platt is not None else p_test_raw

    # Platt evaluation curve on TEST (Ideal -> RAW -> Platt) + metrics row (best effort)
    try:
        dev_platt = (fig_percls / f"{cls_safe}_Platt_calibration_evaluation_on_test.png") if EXPORT_DEV else None
        rel_platt = (release_single / f"{cls_safe}_Platt_calibration_evaluation_on_test.png") if EXPORT_RELEASE else None

        ll_raw, br_raw, ll_pl, br_pl, pl_avail = MLTraining.plot_platt_calibration_on_test(
            y_true_bin=y_te_all.astype(int),
            p_raw=p_test_raw,
            p_platt=p_test_platt,
            title=f"{name_target_class} – {celltype}: Platt calibration evaluation on TEST",
            out_png_dev=dev_platt,
            out_png_rel=rel_platt,
            n_bins=15,
        )

        platt_metrics_rows.append({
            "depth": name_target_class,
            "class_name": str(celltype),
            "n_test_samples": int(len(y_te_all)),
            "n_test_positive": int(y_te_all.sum()),
            "logloss_raw": ll_raw,
            "brier_raw": br_raw,
            "logloss_platt": ll_pl,
            "brier_platt": br_pl,
            "platt_available": bool(pl_avail),
        })

    except Exception as e:
        _report_failure(f"Platt calibration plot failed for class '{celltype}': {e}")

    # 4.8 Save per-class head bundle + keep in-memory for package
    head_bundle = {
        "atlas": "Hao",
        "depth": name_target_class,
        "label": str(celltype),
        "panel_name": "TotalSeqD_Heme_Oncology_CAT399906",
        "columns": cols_train,
        "scaler": scaler,
        "model_raw": stacker_raw,
        "model_platt": stacker_platt,
    }
    heads_mem[str(celltype)] = head_bundle

    if EXPORT_DEV:
        joblib.dump(head_bundle, heads_dir / f"{cls_safe}.joblib")

    # 4.9 Optional per-head metrics logging (class-specific TEST subset). Failures
    # are reported instead of being silently swallowed.
    if X_te_sc_df is None:
        _report_failure(f"No class-specific TEST barcodes found for '{celltype}'; per-head metrics skipped")
    else:
        try:
            model_for_eval = stacker_platt if stacker_platt is not None else stacker_raw
            m = MLTraining.evaluate_classifier(model_for_eval, X_te_sc_df, y_te, plot_cm=False)
            m.update(celltype=str(celltype), used_platt=bool(stacker_platt is not None))
            metrics_log.append(m)
        except Exception as e:
            _report_failure(f"Per-head TEST metrics failed for class '{celltype}': {e}")

    # 4.10 CAL OvR probability columns (RAW + PLATT); TEST columns were filled in 4.7
    P_cal_raw[:, k] = stacker_raw.predict_proba(X_cal_sc_df)[:, 1]
    if stacker_platt is not None:
        P_cal_platt[:, k] = stacker_platt.predict_proba(X_cal_sc_df)[:, 1]
    else:
        P_cal_platt[:, k] = P_cal_raw[:, k]

    # 4.11 SHAP: mean_abs + corr on TEST; beeswarm TRAIN only
    if HAS_SHAP:
        try:
            shap_sum_test = MLTraining.xgb_shap_mean_abs_and_corr(XGB_model, X_te_all_sc_df, class_index=1)
            shap_sum_test["depth"] = name_target_class
            shap_sum_test["class_name"] = str(celltype)
            shap_sum_test["dataset"] = "TEST"
            xgb_shap_rows.extend(shap_sum_test.to_dict(orient="records"))

            # Beeswarm on TRAIN only: SHAP values are computed once and drawn to every
            # enabled export target (previously recomputed for DEV and RELEASE).
            beeswarm_paths = []
            if EXPORT_DEV:
                beeswarm_paths.append(fig_percls / f"{cls_safe}_SHAP_beeswarm_TRAIN.png")
            if EXPORT_RELEASE:
                beeswarm_paths.append(release_single / f"{cls_safe}_SHAP_beeswarm_TRAIN.png")
            if beeswarm_paths:
                sv, _ = MLTraining.xgb_shap_values(XGB_model, X_tr_sc_df, class_index=1)
                for outp in beeswarm_paths:
                    MLTraining.plot_xgb_shap_beeswarm(
                        sv, X_tr_sc_df,
                        title=f"{name_target_class} – {celltype}: SHAP importance",
                        max_display=TOP_N,
                        out_path=outp,
                        figsize=(6, 7),
                    )

        except Exception as e:
            _report_failure(f"SHAP failed for class '{celltype}': {e}")

    # 4.12 LR meta-learner base-learner contributions on TEST.
    # _lr_baselearner_contributions is not part of MLTraining; it is expected to be
    # defined in an earlier notebook cell (a missing helper is reported, not fatal).
    try:
        contrib = _lr_baselearner_contributions(stacker_raw, X_te_all_sc_df, base_order=base_order)
        row = {
            "depth": name_target_class,
            "class_name": str(celltype),
            "dataset": "TEST",
            "n_meta_features": contrib["n_meta_features"],
            "per_estimator_meta_cols": contrib["per_estimator_meta_cols"],
        }
        for b in base_order:
            row[f"{b}_mean_abs_contribution"] = contrib["per_base"].get(b, {}).get("mean_abs_contribution", 0.0)
            row[f"{b}_coef_l1"]               = contrib["per_base"].get(b, {}).get("coef_l1", 0.0)
            row[f"{b}_n_meta_cols"]           = contrib["per_base"].get(b, {}).get("n_cols", 0)
        lr_contrib_rows.append(row)
    except Exception as e:
        _report_failure(f"LR contribution extraction failed for class '{celltype}': {e}")

    print("")


# =============================================================================
# EXPORT: Per-class LogLoss & Brier (pre vs post Platt) on TEST
# =============================================================================
print("\n[EXPORT] Per-class calibration metrics (RAW vs Platt on TEST)...")

# Hands the platt_metrics_rows collected in STEP 4.7 to MLTraining, which presumably
# writes them as a CSV into each enabled target folder (None disables a target).
_ = MLTraining.export_platt_metrics_csv(
    platt_metrics_rows,
    out_dev=metrics_dir if EXPORT_DEV else None,
    out_rel=release_metrics if EXPORT_RELEASE else None,
    filename="Single_classes_metrics_pre_and_post_platt_calibration.csv",
)


# =============================================================================
# SECTION 5: MULTICLASS TEMPERATURE SCALING (fit on CAL using PLATT matrix)
# =============================================================================
print("\n[STEP 5] Multiclass Temperature Scaling on CAL (using Platt OvR probabilities)...")

def _check_probs(P: np.ndarray, name: str):
    if np.isnan(P).any() or np.isinf(P).any():
        raise ValueError(f"{name} contains NaN/Inf")
    if (P < 0).any() or (P > 1).any():
        raise ValueError(f"{name} contains values outside [0,1]")

_check_probs(P_cal_platt, "P_cal_platt")
_check_probs(P_te_platt,  "P_te_platt")

# Fit one temperature on CAL (Platt OvR matrix vs. integer labels) and apply it to
# TEST. NOTE(review): the OvR columns are not row-normalised before fitting —
# confirm that netcal's TemperatureScaling handles such inputs as intended.
ts_cal = TemperatureScaling()
ts_cal.fit(P_cal_platt, y_cal_multiclass)
P_te_cal = np.asarray(ts_cal.transform(P_te_platt))
# Shape exactly as returned by TemperatureScaling, for the fallback warning below
# (previously the warning printed the shape *after* the fallback replaced P_te_cal).
ts_out_shape = P_te_cal.shape

# Coerce the output to (n_test, K): for K == 2 a 1-D / one-column result is taken
# as the positive-class probability; any other mismatch falls back to the
# row-normalised Platt OvR matrix.
if P_te_cal.ndim == 1:
    P_te_cal = P_te_cal.reshape(-1, 1)
if P_te_cal.shape[1] == 1 and K == 2:
    P_te_cal = np.hstack([1.0 - P_te_cal, P_te_cal])
elif P_te_cal.shape[1] != K:
    row_sums = P_te_platt.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0.0] = 1.0  # all-zero rows stay zero instead of dividing by 0
    P_te_cal = P_te_platt / row_sums
    print(f"  [WARN] TemperatureScaling returned shape {ts_out_shape}; fell back to sum-normalized OvR probs")

if EXPORT_DEV:
    joblib.dump(ts_cal, models_root / "temp_scaler.joblib")
    pd.Series(class_names, name="class_name").to_csv(models_root / "class_names.csv", index=False)


# =============================================================================
# SECTION 5b: SAVE DEPLOYABLE PACKAGE(S)
# =============================================================================
print("\n[STEP 5b] Saving deployable package(s)...")

# Self-contained bundle: per-class heads (panel columns, scaler, RAW and Platt
# stackers; see STEP 4.8) plus the multiclass temperature scaler from STEP 5.
# NOTE(review): classes skipped in STEP 4 are listed in class_names but have no
# entry in "heads"; consumers should not assume one head per class.
package = {
    "atlas": "Hao",
    "depth": name_target_class,
    "panel_name": "TotalSeqD_Heme_Oncology_CAT399906",
    "class_names": class_names,
    "heads": heads_mem,
    "temp_scaler": ts_cal,
}

if EXPORT_DEV:
    joblib.dump(package, models_root / "package.joblib")

if EXPORT_RELEASE:
    joblib.dump(package, release_models / "Multiclass_models.joblib")


# =============================================================================
# SECTION 5c: EXPORT IMPORTANCES (Top10 per class)
# =============================================================================
print("\n[STEP 5c] Exporting importances (Top 10 per class; SHAP mean_abs + corr + LR)...")

# SHAP export: keep the TOP_N features with the largest mean |SHAP| per
# (depth, class), together with the feature-value vs. SHAP correlation.
shap_df = None
if xgb_shap_rows:
    shap_df = (
        pd.DataFrame(xgb_shap_rows)
          .sort_values(["depth", "class_name", "mean_abs_shap"], ascending=[True, True, False])
          .groupby(["depth", "class_name"], as_index=False)
          .head(TOP_N)
    )

    # Rows are already ordered by descending mean_abs_shap within each class, so the
    # 1-based position inside the group is the rank (ties keep row order, exactly like
    # rank(method="first")). Unlike rank(), cumcount never yields NaN, which made the
    # int cast fail whenever a class's SHAP summary contained NaN.
    shap_df["rank_within_class"] = shap_df.groupby(["depth", "class_name"]).cumcount() + 1

    keep_cols = [
        "depth", "class_name", "dataset",
        "feature", "mean_abs_shap", "corr_feature_value_vs_shap",
        "rank_within_class",
    ]
    shap_df = shap_df[keep_cols]

    if EXPORT_DEV:
        shap_df.to_csv(dev_importances / "SHAP_XGB_Feature_importances.csv", index=False)
    if EXPORT_RELEASE:
        shap_df.to_csv(release_imps / "SHAP_XGB_Feature_importances.csv", index=False)
else:
    print("  [INFO] No SHAP rows collected (or SHAP not installed).")

# LR meta-learner export: one row per class with per-base-learner contributions.
lr_df = None
if lr_contrib_rows:
    lr_df = pd.DataFrame(lr_contrib_rows)
    if EXPORT_DEV:
        lr_df.to_csv(dev_importances / "LR_MetaLearner_BaseLearner_contributions.csv", index=False)
    if EXPORT_RELEASE:
        lr_df.to_csv(release_imps / "LR_MetaLearner_BaseLearner_contributions.csv", index=False)
else:
    print("  [INFO] No LR contribution rows collected.")


# =============================================================================
# SECTION 6: SAVE PROBABILITIES
# =============================================================================
print("\n[STEP 6] Saving probability outputs...")

# DEV: RAW, PLATT and temperature-calibrated TEST probabilities side by side, plus
# true labels and argmax predictions (RAW and CAL) as indices and class names.
if EXPORT_DEV:
    probs_raw_df   = pd.DataFrame(P_te_raw,   index=test_index, columns=[f"raw_{c}"   for c in class_names])
    probs_platt_df = pd.DataFrame(P_te_platt, index=test_index, columns=[f"platt_{c}" for c in class_names])
    probs_cal_df   = pd.DataFrame(P_te_cal,   index=test_index, columns=[f"cal_{c}"   for c in class_names])

    probs_dev = pd.concat([probs_raw_df, probs_platt_df, probs_cal_df], axis=1)
    probs_dev["true_label"] = Hao_data_Test["Celltype"].values
    probs_dev["pred_raw"]   = P_te_raw.argmax(axis=1)
    probs_dev["pred_cal"]   = P_te_cal.argmax(axis=1)
    probs_dev["pred_raw_name"] = [class_names[i] for i in probs_dev["pred_raw"].values]
    probs_dev["pred_cal_name"] = [class_names[i] for i in probs_dev["pred_cal"].values]

    probs_dev_path = probs_dir / "probabilities_before_after_TEST.csv"
    probs_dev.to_csv(probs_dev_path, index=True)

# RELEASE: calibrated probabilities only, plus the argmax prediction and its probability.
if EXPORT_RELEASE:
    probs_cal_df = pd.DataFrame(P_te_cal, index=test_index, columns=[f"cal_{c}" for c in class_names])
    probs_release = probs_cal_df.copy()
    probs_release["true_label"]    = Hao_data_Test["Celltype"].values
    probs_release["pred_cal"]      = P_te_cal.argmax(axis=1)
    probs_release["pred_cal_name"] = [class_names[i] for i in probs_release["pred_cal"].values]
    probs_release["max_cal_prob"]  = probs_cal_df.max(axis=1).values

    release_probs_path = release_probs / "Multiclass_models_probabilities_on_test.csv"
    probs_release.to_csv(release_probs_path, index=True)


# =============================================================================
# SECTION 7: MULTICLASS EVALUATION (TEST) — using CAL probabilities
# =============================================================================
print("\n[STEP 7] Multiclass evaluation (TEST; using CAL probs)...\n")

y_pred_cal = P_te_cal.argmax(axis=1)

# Evaluate over every trained class index, consistent with the confusion matrix.
# Without an explicit `labels`, classification_report infers labels from
# y_true/y_pred and raises ValueError whenever a class is absent from both (e.g. a
# class skipped in STEP 4 that is also missing from TEST), because the inferred
# label count no longer matches `target_names`.
mc_labels = list(range(K))

report_txt = classification_report(
    y_test_multiclass, y_pred_cal, labels=mc_labels, target_names=class_names, digits=3
)
print("Multiclass Classification Report (TEST):")
print(report_txt)

cm_mc = confusion_matrix(y_test_multiclass, y_pred_cal, labels=mc_labels)
print("\nConfusion Matrix (rows=true, cols=pred):")
print(cm_mc)

report = classification_report(
    y_test_multiclass, y_pred_cal, labels=mc_labels, target_names=class_names, output_dict=True
)
report_df = pd.DataFrame(report).T

cm_mc_df = pd.DataFrame(
    cm_mc,
    index=pd.Index(class_names, name="true"),
    columns=pd.Index(class_names, name="pred"),
)

if EXPORT_DEV:
    report_df.to_csv(metrics_dir / "multiclass_classification_report_TEST.csv")
    cm_mc_df.to_csv(metrics_dir / "multiclass_confusion_matrix_TEST.csv")

if EXPORT_RELEASE:
    report_df.to_csv(release_metrics / "Multiclass_models_metrics_on_test.csv")
    cm_mc_df.to_csv(release_metrics / "Multiclass_models_confusion_matrix_on_test.csv")


# =============================================================================
# SECTION 8: FIGURES (MULTICLASS CM + PER-CLASS CONF & ROC)
# =============================================================================
print("\n[STEP 8] Saving plots...")

def _save_multiclass_cm_png(out_path: Path):
    """Save the K x K multiclass TEST confusion matrix as a 300-dpi PNG.

    Reads the module-level ``cm_mc``, ``class_names`` and ``name_target_class``.
    The figure is always closed, even if saving fails.
    """
    fig, ax = plt.subplots(figsize=(7, 6))
    try:
        # Plot into our own axes: ConfusionMatrixDisplay.plot() without ``ax`` creates
        # a new figure, so figsize=(7, 6) was ignored and the plotted figure was left
        # open (plt.close only closed the empty one), accumulating figures per call.
        disp = ConfusionMatrixDisplay(confusion_matrix=cm_mc, display_labels=class_names)
        disp.plot(ax=ax, values_format="d", cmap="Blues", colorbar=False)
        ax.set_title(f"{name_target_class} – Multiclass Confusion Matrix (on TEST)")
        fig.tight_layout()
        fig.savefig(out_path, dpi=300)
    finally:
        plt.close(fig)

if EXPORT_DEV:
    _save_multiclass_cm_png(fig_root / "multiclass_confusion_matrix_TEST.png")

if EXPORT_RELEASE:
    _save_multiclass_cm_png(release_figs / "Multiclass_models_confusion_matrix_on_test.png")

# Collector for per-class summary rows (presumably appended to later in this cell).
per_class_rows = []

# Argmax predictions from RAW OvR scores and from temperature-calibrated
# probabilities; the per-class one-vs-rest confusion matrices below compare both.
y_pred_raw = P_te_raw.argmax(axis=1)
y_pred_cal = P_te_cal.argmax(axis=1)

def _metrics_from_cm(cm2x2):
    tn, fp, fn, tp = cm2x2.ravel()
    support = tp + fn
    prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    rec  = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1   = (2 * prec * rec) / (prec + rec) if (prec + rec) > 0 else 0.0
    return dict(TP=int(tp), FP=int(fp), TN=int(tn), FN=int(fn),
                support=int(support), precision=prec, recall=rec, f1=f1)

def _save_cm_fig(cm2x2, cls_label, title, out_dev: Path | None, out_rel: Path | None):
    """Render a 2x2 one-vs-rest confusion matrix and save it to each non-None path.

    Args:
        cm2x2: 2x2 confusion matrix with rows/cols ordered [Other, cls_label].
        cls_label: display label of the positive class.
        title: figure title.
        out_dev, out_rel: PNG destinations (300 dpi); ``None`` skips that target.

    The figure is always closed, even if saving fails.
    """
    fig, ax = plt.subplots(figsize=(5.5, 5.0))
    try:
        # Draw into our own axes: ConfusionMatrixDisplay.plot() without ``ax`` opens a
        # new figure, which ignored figsize and left that figure open after closing.
        ConfusionMatrixDisplay(confusion_matrix=cm2x2, display_labels=["Other", cls_label]).plot(
            ax=ax, values_format="d", cmap="Blues", colorbar=False
        )
        ax.set_title(title)
        fig.tight_layout()
        for out_path in (out_dev, out_rel):
            if out_path is not None:
                fig.savefig(out_path, dpi=300)
    finally:
        plt.close(fig)

def _save_roc(y_true, y_score, title, out_dev: Path | None, out_rel: Path | None):
    """Compute the ROC AUC of ``y_score`` against binary ``y_true`` and save the curve.

    Parameters
    ----------
    y_true : array-like of {0, 1}
        Binary ground truth (1 = positive class).
    y_score : array-like of float
        Score / probability for the positive class.
    title : str
        Title prefix; ``" AUC=<value>"`` is appended.
    out_dev, out_rel : Path or None
        DEV / RELEASE PNG destinations; ``None`` skips that target.

    Returns
    -------
    float
        The ROC AUC. It is always computed, even when no PNG is requested. If
        ``y_true`` holds a single class, scikit-learn warns and the AUC is NaN.
    """
    fpr, tpr, _ = roc_curve(y_true, y_score)
    roc_auc = auc(fpr, tpr)

    targets = [p for p in (out_dev, out_rel) if p is not None]
    if not targets:
        return roc_auc  # metrics only: skip rendering a figure nobody saves

    fig, ax = plt.subplots(figsize=(6.0, 5.5))
    try:
        ax.plot([0, 1], [0, 1], linestyle="--", linewidth=1, color="gray")  # chance diagonal
        ax.plot(fpr, tpr)
        ax.set_xlabel("False Positive Rate")
        ax.set_ylabel("True Positive Rate")
        ax.set_title(f"{title} AUC={roc_auc:.3f}")
        fig.tight_layout()
        for target in targets:
            fig.savefig(target, dpi=300)
    finally:
        # Close even if savefig raises, so figures don't accumulate across classes.
        plt.close(fig)
    return roc_auc

# Per-class one-vs-rest view of the multiclass argmax predictions on TEST. For each
# class k, truth and prediction are collapsed to "k vs Other". The 2x2 confusion
# matrices and ROC curves are then saved for both the RAW and the calibrated (CAL) matrix.
for k, cls in enumerate(class_names):
    cls_safe = MLTraining.safe_name(cls)  # filesystem-safe form of the class label
    y_true_bin = (y_test_multiclass == k).astype(int)

    # Column k of each probability matrix is the class-k score used for the ROC curves.
    score_raw = P_te_raw[:, k]
    score_cal = P_te_cal[:, k]

    # Binary predictions come from the multiclass argmax, not from thresholding column k.
    y_pred_raw_bin = (y_pred_raw == k).astype(int)
    y_pred_cal_bin = (y_pred_cal == k).astype(int)

    # labels=[0, 1] forces a 2x2 matrix even when one side is absent.
    cm_raw = confusion_matrix(y_true_bin, y_pred_raw_bin, labels=[0, 1])
    cm_cal = confusion_matrix(y_true_bin, y_pred_cal_bin, labels=[0, 1])

    if EXPORT_DEV:
        idx = pd.Index(["True=Other", f"True={cls}"], name="true")
        cols = pd.Index(["Pred=Other", f"Pred={cls}"], name="pred")
        pd.DataFrame(cm_raw, index=idx, columns=cols).to_csv(metrics_dir / f"{cls_safe}_binary_confmat_TEST_ARGMAX_RAW.csv")
        pd.DataFrame(cm_cal, index=idx, columns=cols).to_csv(metrics_dir / f"{cls_safe}_binary_confmat_TEST_ARGMAX_CAL.csv")

    # Each figure helper receives None for a disabled export target.
    dev_out = (fig_percls / f"{cls_safe}_RAW_confusion_matrix_on_test.png") if EXPORT_DEV else None
    rel_out = (release_single / f"{cls_safe}_RAW_confusion_matrix_on_test.png") if EXPORT_RELEASE else None
    _save_cm_fig(cm_raw, cls, f"{name_target_class} – {cls}: Confusion Matrix (RAW; pre-Platt & Temp)", dev_out, rel_out)

    dev_out = (fig_percls / f"{cls_safe}_CAL_confusion_matrix_on_test.png") if EXPORT_DEV else None
    rel_out = (release_single / f"{cls_safe}_CAL_confusion_matrix_on_test.png") if EXPORT_RELEASE else None
    _save_cm_fig(cm_cal, cls, f"{name_target_class} – {cls}: Confusion Matrix (CAL; Platt & Temp)", dev_out, rel_out)

    dev_out = (fig_percls / f"{cls_safe}_RAW_ROC_on_test.png") if EXPORT_DEV else None
    rel_out = (release_single / f"{cls_safe}_RAW_ROC_on_test.png") if EXPORT_RELEASE else None
    auc_raw = _save_roc(
        y_true_bin,
        score_raw,
        f"{name_target_class} – {cls}: ROC (RAW; pre-Platt & Temp)",
        dev_out,
        rel_out,
    )

    dev_out = (fig_percls / f"{cls_safe}_CAL_ROC_on_test.png") if EXPORT_DEV else None
    rel_out = (release_single / f"{cls_safe}_CAL_ROC_on_test.png") if EXPORT_RELEASE else None
    auc_cal = _save_roc(
        y_true_bin,
        score_cal,
        f"{name_target_class} – {cls}: ROC (CAL; Platt & Temp)",
        dev_out,
        rel_out,
    )

    # Two rows per class (RAW then CAL): argmax-based counts/ratios plus the ROC AUC.
    m_raw = _metrics_from_cm(cm_raw)
    m_raw.update(model="RAW", class_name=cls, auc=auc_raw)
    per_class_rows.append(m_raw)

    m_cal = _metrics_from_cm(cm_cal)
    m_cal.update(model="CAL", class_name=cls, auc=auc_cal)
    per_class_rows.append(m_cal)

if EXPORT_DEV:
    print(f"  ✓ Saved per-class plots (DEV) → {fig_percls}")
if EXPORT_RELEASE:
    print(f"  ✓ Saved per-class plots (RELEASE) → {release_single}")


# =============================================================================
# SECTION 9: SAVE METRICS TABLES
# =============================================================================
print("\n[STEP 9] Saving metrics tables...")

# One row per (class, RAW|CAL), with a fixed column order and a stable sort.
per_class_df = pd.DataFrame(per_class_rows)[
    ["class_name", "model", "TP", "FP", "TN", "FN", "support", "precision", "recall", "f1", "auc"]
].sort_values(["class_name", "model"])

if EXPORT_DEV:
    dev_metrics_path = metrics_dir / "per_class_argmax_metrics_TEST_included.csv"
    per_class_df.to_csv(dev_metrics_path, index=False)
    print(f"  ✓ Saved DEV per-class metrics → {dev_metrics_path}")

if EXPORT_RELEASE:
    out_single = release_metrics / "Single_classes_metrics_and_confusion_matrix_on_test.csv"
    per_class_df.to_csv(out_single, index=False)
    print(f"  ✓ Saved RELEASE per-class metrics → {out_single}")

if EXPORT_DEV:
    # `metrics_log` presumably holds the per-head records collected during head
    # training earlier in this pipeline; they are appended to a cumulative CSV.
    metrics_df = pd.DataFrame.from_records(metrics_log)
    MLTraining.append_metrics_csv(metrics_df, csv_path=dev_root / "stacker_metrics.csv")
    print(f"  ✓ Appended DEV binary-head metrics → {dev_root / 'stacker_metrics.csv'}")

print("\n✅ BROAD PIPELINE COMPLETE. Exports saved according to EXPORT_DEV / EXPORT_RELEASE.\n")


#### Simplified annotation

In [None]:
# -*- coding: utf-8 -*-
# =============================================================================
# MODEL TRAINING PIPELINE (LEAN MAIN SCRIPT)
#   - RAW vs PLATT vs TEMP-SCALED
#   - DEV/RELEASE exports
#   - Importances: XGB SHAP mean_abs + corr (Top10) + LR meta-learner contributions
#   - Platt calibration plots (Ideal -> RAW -> Platt on top) with TEST LogLoss/Brier in legend
#   - Per-class pre/post Platt metrics exported to CSV
#   - Per-class TRAIN UMAP (pos vs rest) + legend PNG
#
# PATCHES ADDED (to address “plots missing / skipped” symptoms):
#   (A) Optional DEBUG_DIAGNOSTICS: prints output paths + CAL class balance + confirms file writes.
#   (B) Hard traceback on failures (instead of silent warnings) to surface root cause.
#   (C) SHAP beeswarm robustification: optional subsample of TRAIN to avoid memory/time failures.
#   (D) Optional SAFE_SINGLE_THREAD: mitigates fork/thread/numba/TBB instability during SHAP/plotting.
#   (E) Explicit existence checks after savefig (so “saved but not where expected” is obvious).
# =============================================================================

# =============================================================================
# SECTION 0: IMPORTS + CONFIG
# =============================================================================

from pathlib import Path
import joblib
import warnings
import traceback
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import StackingClassifier
from sklearn.metrics import (
    classification_report, confusion_matrix, ConfusionMatrixDisplay,
    roc_curve, auc,
)

import MLTraining  # uses MLTraining.py helpers

# -----------------------------------------------------------------------------
# Palettes
# -----------------------------------------------------------------------------
# Hex colours per cell-type label, one palette per annotation depth. Keys must
# match the label strings used in the consensus annotation columns.

PALETTE_BROAD = {"Immature": "#0079ea", "Mature": "#AF3434"}

PALETTE_SIMPLIFIED = {
    "HSPC":      "#0079ea",
    "Erythroid": "#c11212",
    "pDC":       "#62E6B8",
    "Monocyte":  "#D27CE3",
    "Myeloid":   "#8D43CD",
    "CD4_T":     "#C99546",
    "CD8_T":     "#6B3317",
    "B":         "#68D827",
    "cDC":       "#16D2E3",
    "Other_T":   "#EDB416",
    "NK":        "#FBEF0D",
}

# NOTE(review): "Monocyte" and "CD14 Mono" share "#D27CE3" — presumably intentional
# (generic monocytes drawn like CD14 monocytes); confirm.
PALETTE_DETAILED = {
    "HSC_MPP":            "#0079ea",
    "LMPP":               "#17BECF",
    "GMP":                "#C5E4FF",
    "Myeloid progenitor": "#AEC7E8",
    "Monocyte":           "#D27CE3",
    "CD14 Mono":          "#D27CE3",
    "CD16 Mono":          "#8D43CD",
    "Erythroblast":       "#F30A1A",
    "ErP":                "#D1235A",
    "MEP":                "#E364B0",
    "CD4 T Naive":        "#C99546",
    "CD4 T Memory":       "#C1AF93",
    "CD8 T Naive":        "#4D382E",
    "CD8 T Memory":       "#6B3317",
    "Other_T":            "#EDB416",
    "Treg":               "#6E6C37",
    "B Naive":            "#1C511D",
    "B Memory":           "#68D827",
    "Pro-B":              "#66BB6A",
    "Pre-B":              "#2DBD67",
    "Immature B":         "#91FF7B",
    "Plasma":             "#9DC012",
    "cDC1":               "#76A7CB",
    "cDC2":               "#16D2E3",
    "pDC":                "#69FFCB",
    "NK CD56 bright":     "#F3AC1F",
    "NK CD56 dim":        "#FBEF0D",
}

# Lookup by `name_target_class` ("Broad" | "Simplified" | "Detailed").
PALETTE_BY_DEPTH = {
    "Broad": PALETTE_BROAD,
    "Simplified": PALETTE_SIMPLIFIED,
    "Detailed": PALETTE_DETAILED,
}

# -----------------------------------------------------------------------------
# OPTIONAL: SHAP dependency
# -----------------------------------------------------------------------------
# Broad `except Exception` on purpose: SHAP outputs are optional, so any
# import-time failure (not only ImportError) simply disables them.
try:
    import shap  # noqa: F401
    HAS_SHAP = True
except Exception:
    HAS_SHAP = False

# -----------------------------------------------------------------------------
# EXPORT SWITCHES
# -----------------------------------------------------------------------------
EXPORT_RELEASE = True   # write artefacts under <models_output>/Release
EXPORT_DEV     = False  # write artefacts under <models_output>/Dev

# -----------------------------------------------------------------------------
# CONFIG
# -----------------------------------------------------------------------------
name_target_class = "Simplified"  # "Broad" | "Simplified" | "Detailed"
# Class labels to leave out of training/evaluation, compared against str(label).
# This must be a set: `{}` is an empty *dict*, not an empty set.
EXCLUDE_CLASSES: set[str] = set()

custom_palette = PALETTE_BY_DEPTH.get(name_target_class, {})  # empty palette for an unknown depth
kf          = MLTraining.CV
num_cores   = -1  # forwarded to MLTraining.train_* (presumably joblib-style: -1 = all cores)
metrics_log = []  # per-head TEST metrics dicts, appended while training the heads

# -----------------------------------------------------------------------------
# DIAGNOSTICS / ROBUSTIFICATION SWITCHES (PATCH)
# -----------------------------------------------------------------------------
DEBUG_DIAGNOSTICS = True
HARD_TRACEBACKS   = True   # if True: prints stack traces when plot/SHAP fails
SHAP_TRAIN_SUBSAMPLE_MAX_N = 5000  # set None to disable subsampling
# NOTE(review): SAFE_SINGLE_THREAD is only printed in this cell; confirm it is consumed elsewhere.
SAFE_SINGLE_THREAD = False  # set True if you see Numba/TBB fork/thread warnings

# -----------------------------------------------------------------------------
# EMBEDDING CONFIG (for Class_Train_data.png)
# -----------------------------------------------------------------------------
# EMBEDDING_SOURCE selects where the 2-D coordinates come from. The remaining
# names are the obsm key / obs columns / DataFrame columns read for that source.
EMBEDDING_SOURCE = "adata_obsm"   # "adata_obsm" | "adata_obs" | "train_df"
EMBEDDING_OBSM_KEY = "X_wnn.umap"
EMBEDDING_OBS_X = "UMAP_1"
EMBEDDING_OBS_Y = "UMAP_2"
EMBEDDING_DF_X = "UMAP_1"
EMBEDDING_DF_Y = "UMAP_2"

# -----------------------------------------------------------------------------
# ROOTS
# -----------------------------------------------------------------------------
# All outputs live under `models_output` (defined earlier in the notebook), split
# into a Dev tree and a Release tree for the current annotation depth.
hao_root = Path(models_output)

# NOTE(review): the depth name appears twice in the DEV paths
# (e.g. Dev/Simplified/Models/Simplified) — confirm this nesting is intended.
dev_root     = hao_root / "Dev"
models_root  = dev_root / name_target_class / "Models"  / name_target_class
reports_root = dev_root / name_target_class / "Reports" / name_target_class
fig_root     = dev_root / name_target_class / "Figures" / name_target_class

heads_dir       = models_root / "heads"
metrics_dir     = reports_root / "metrics"
probs_dir       = reports_root / "probabilities"
fig_percls      = fig_root / "per_class"
dev_importances = reports_root / "Importances"

release_root    = hao_root / "Release"
release_models  = release_root / name_target_class / "Models"
release_reports = release_root / name_target_class / "Reports"
release_metrics = release_reports / "Metrics"
release_probs   = release_reports / "Probabilities"
release_imps    = release_reports / "Importances"
release_figs    = release_root / name_target_class / "Figures"
release_single  = release_figs / "Single_classes"

# Directories are created only for the enabled export targets.
if EXPORT_DEV:
    for p in (models_root, heads_dir, reports_root, metrics_dir, probs_dir, fig_root, fig_percls, dev_importances):
        p.mkdir(parents=True, exist_ok=True)

if EXPORT_RELEASE:
    for p in (release_models, release_reports, release_metrics, release_probs, release_imps, release_figs, release_single):
        p.mkdir(parents=True, exist_ok=True)
    print(f"[INFO] RELEASE Root:    {release_root}")
    print(f"[INFO] RELEASE Models:  {release_models}")
    print(f"[INFO] RELEASE Reports: {release_reports}")
    print(f"[INFO] RELEASE Figures: {release_figs}")

if DEBUG_DIAGNOSTICS:
    print(f"[DEBUG] HAS_SHAP={HAS_SHAP} EXPORT_RELEASE={EXPORT_RELEASE} EXPORT_DEV={EXPORT_DEV}")
    print(f"[DEBUG] release_single={release_single}")
    print(f"[DEBUG] release_imps={release_imps}")
    print(f"[DEBUG] SAFE_SINGLE_THREAD={SAFE_SINGLE_THREAD} SHAP_SUBSAMPLE_MAX_N={SHAP_TRAIN_SUBSAMPLE_MAX_N}")

# =============================================================================
# SECTION 1: ATTACH CELL-TYPE LABELS
# =============================================================================
print("\n[STEP 1] Attaching cell-type labels from AnnData.obs...")

# obs column with the consensus label for this depth,
# e.g. "Consensus_annotation_simplified_final".
consensus_field = f"Consensus_annotation_{name_target_class.lower()}_final"
# Each split's feature frame gets its labels from the matching AnnData object.
# Later code reads them from a "Celltype" column.
Hao_data_Train = MLTraining.attach_celltype(Hao_data_Train, Hao_dataset_Train, consensus_field)
Hao_data_Test  = MLTraining.attach_celltype(Hao_data_Test,  Hao_dataset_Test,  consensus_field)
Hao_data_Cal   = MLTraining.attach_celltype(Hao_data_Cal,   Hao_dataset_Cal,   consensus_field)

print(f"  ✓ Attached '{consensus_field}' to Train/Test/Cal splits")

# =============================================================================
# SECTION 2: ALIGN DATA COLUMNS TO REFERENCE PANEL
# =============================================================================
print("\n[STEP 2] Aligning data columns to reference panel (exact names preserved)...")

# Reference antibody panel (exact names) and a normalised-key -> exact-name lookup,
# so that data columns spelled slightly differently can be renamed to panel names.
panel = pd.Index(map(str, TotalSeqD_Heme_Oncology_CAT399906))
panel_keys = MLTraining.norm_feats(panel)
norm_to_panel = dict(zip(panel_keys, panel))
# Two panel names normalising to the same key would make the lookup ambiguous.
if len(norm_to_panel) != len(panel):
    raise ValueError("Panel contains names that collide after normalization. Adjust MLTraining.norm_feats rules.")

def rename_data_to_panel(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` whose feature columns carry exact panel names.

    Every column except ``cell_barcode`` / ``Celltype`` is normalised with
    ``MLTraining.norm_feats`` and looked up in ``norm_to_panel``. Matches are
    renamed to the canonical panel name. Unmatched columns keep their name.
    When several columns map to the same panel name, the first one wins and the
    others are dropped (with a warning).
    """
    out = df.copy()
    meta_cols = [name for name in ("cell_barcode", "Celltype") if name in out.columns]
    feature_cols = pd.Index([name for name in out.columns if name not in meta_cols])

    # original column -> canonical panel name (only for columns found on the panel)
    candidates = {}
    for original, key in zip(feature_cols, MLTraining.norm_feats(feature_cols)):
        target = norm_to_panel.get(key)
        if target is not None:
            candidates[original] = target

    # Keep the first column per panel name; later ones would create duplicate labels.
    accepted, duplicates, claimed = {}, [], set()
    for original, target in candidates.items():
        if target in claimed:
            duplicates.append(original)
            continue
        claimed.add(target)
        accepted[original] = target

    if duplicates:
        print(f"  [WARN] Dropping {len(duplicates)} duplicated-mapped columns (sample: {duplicates[:5]})")
        out = out.drop(columns=duplicates, errors="ignore")

    out = out.rename(columns=accepted)
    print(f"  ✓ Matched {len(accepted)}/{len(feature_cols)} data columns to panel")
    return out

def panel_intersection(df: pd.DataFrame) -> pd.DataFrame:
    """Restrict ``df`` to panel features (in panel order), followed by its meta columns.

    Feature columns are those other than ``cell_barcode`` / ``Celltype``. Their
    intersection with the global ``panel`` is taken in ``panel`` order.

    Raises
    ------
    ValueError
        If no feature column matches the panel.
    """
    meta_cols = [name for name in ("cell_barcode", "Celltype") if name in df.columns]
    candidate_feats = pd.Index([name for name in df.columns if name not in meta_cols])
    shared = panel.intersection(candidate_feats, sort=False)  # keeps panel ordering
    if shared.empty:
        raise ValueError("Panel/Data intersection is empty after renaming. Check mapping rules.")
    return df.reindex(columns=[*shared, *meta_cols])

# Rename to panel names, then keep only panel features (+ meta columns) for every split.
Hao_data_Train = panel_intersection(rename_data_to_panel(Hao_data_Train))
Hao_data_Test  = panel_intersection(rename_data_to_panel(Hao_data_Test))
Hao_data_Cal   = panel_intersection(rename_data_to_panel(Hao_data_Cal))
print("  ✓ Data columns now aligned to panel (panel order preserved)")

# =============================================================================
# SECTION 3: PREPARE FEATURES & LABELS (WITH CAL/TEST ROW FILTERING)
# =============================================================================
print("\n[STEP 3] Extracting features and labels...")

# CAL labels kept separately before the feature-only frames are built.
Hao_data_Cal_lbl = Hao_data_Cal[["Celltype"]].copy()

drop_cols_train = [c for c in ["cell_barcode", "Celltype"] if c in Hao_data_Train.columns]
drop_cols_test  = [c for c in ["cell_barcode", "Celltype"] if c in Hao_data_Test.columns]
drop_cols_cal   = [c for c in ["cell_barcode", "Celltype"] if c in Hao_data_Cal.columns]

# "*_Sub" frames: numeric features only (meta columns removed).
Hao_data_Train_Sub = Hao_data_Train.drop(columns=drop_cols_train, errors="ignore")
Hao_data_Test_Sub  = Hao_data_Test.drop(columns=drop_cols_test,  errors="ignore")
Hao_data_Cal_Sub   = Hao_data_Cal.drop(columns=drop_cols_cal,    errors="ignore")

# All splits must share the same feature order, because the scalers/models are
# fitted on TRAIN columns.
cols_train = list(Hao_data_Train_Sub.columns)
if list(Hao_data_Test_Sub.columns) != cols_train or list(Hao_data_Cal_Sub.columns) != cols_train:
    raise ValueError("Train/Cal/Test feature columns differ after panel intersection!")

MLTraining.check_finite(Hao_data_Train_Sub, "TRAIN")
MLTraining.check_finite(Hao_data_Test_Sub,  "TEST")
MLTraining.check_finite(Hao_data_Cal_Sub,   "CAL")

print(f"  ✓ Using {len(cols_train)} panel-intersected features (exact panel names)")
print(f"    Sample: {cols_train[:5]}...")

# Classes are learned from TRAIN, minus EXCLUDE_CLASSES. Exclusion is matched on
# str(label) both when filtering and when reporting, so the two always agree.
all_classes = sorted(pd.Series(Hao_data_Train["Celltype"]).dropna().unique())
class_names = [c for c in all_classes if str(c) not in EXCLUDE_CLASSES]

excluded_present = sorted(str(c) for c in all_classes if str(c) in EXCLUDE_CLASSES)
if excluded_present:
    print(f"  [INFO] Excluding {len(excluded_present)} classes: {excluded_present}")

# Nothing downstream (OvR heads, temperature scaling) is meaningful with zero classes.
if not class_names:
    raise ValueError("No classes left to train after applying EXCLUDE_CLASSES.")

K            = len(class_names)
class_to_idx = {c: i for i, c in enumerate(class_names)}  # label -> column index in P_* matrices
print(f"  ✓ Found {K} classes after exclusions")

# ---- critical: filter CAL/TEST rows to those classes ----
# CAL/TEST rows whose label is not a trained class would have no column in the
# probability matrices, so they are dropped before anything is mapped.
keep_set = set(map(str, class_names))

cal_keep_mask  = Hao_data_Cal_lbl["Celltype"].astype(str).isin(keep_set)
test_keep_mask = Hao_data_Test["Celltype"].astype(str).isin(keep_set)

n_cal_drop  = int((~cal_keep_mask).sum())
n_test_drop = int((~test_keep_mask).sum())

if n_cal_drop > 0:
    dropped = sorted(Hao_data_Cal_lbl.loc[~cal_keep_mask, "Celltype"].astype(str).unique().tolist())
    print(f"  [INFO] Dropping {n_cal_drop} CAL rows with excluded/unknown labels: {dropped}")

if n_test_drop > 0:
    dropped = sorted(Hao_data_Test.loc[~test_keep_mask, "Celltype"].astype(str).unique().tolist())
    print(f"  [INFO] Dropping {n_test_drop} TEST rows with excluded/unknown labels: {dropped}")

# filtered label frames
Hao_data_Cal_lbl_f  = Hao_data_Cal_lbl.loc[cal_keep_mask].copy()
Hao_data_Test_lbl_f = Hao_data_Test.loc[test_keep_mask, ["Celltype"]].copy()

# filtered feature frames (must align by index)
# NOTE(review): .loc by index label assumes unique barcodes in each split — confirm.
X_cal_all_df = Hao_data_Cal_Sub.loc[Hao_data_Cal_lbl_f.index].copy()
X_te_all_df  = Hao_data_Test_Sub.loc[Hao_data_Test_lbl_f.index].copy()
test_index   = X_te_all_df.index

# map filtered labels
# Labels -> integer class indices. After filtering, a NaN here means the label and
# class_names disagree (e.g. a str vs non-str dtype mismatch), so fail loudly.
s_cal = Hao_data_Cal_lbl_f["Celltype"].map(class_to_idx)
if s_cal.isna().any():
    missing = Hao_data_Cal_lbl_f.loc[s_cal.isna(), "Celltype"].unique()
    raise ValueError(f"Still-unmapped labels in CAL after filtering: {missing}")
y_cal_multiclass = s_cal.to_numpy(dtype=np.int64)

s_te = Hao_data_Test_lbl_f["Celltype"].map(class_to_idx)
if s_te.isna().any():
    missing = Hao_data_Test_lbl_f.loc[s_te.isna(), "Celltype"].unique()
    raise ValueError(f"Still-unmapped labels in TEST after filtering: {missing}")
y_test_multiclass = s_te.to_numpy(dtype=np.int64)

# probability matrices sized to filtered CAL/TEST
# Shape (n_rows, K): column k holds the positive-class probability of the class-k
# OvR head. RAW = stacker output, PLATT = after sigmoid calibration on CAL.
P_cal_raw   = np.zeros((X_cal_all_df.shape[0], K), dtype=float)
P_cal_platt = np.zeros((X_cal_all_df.shape[0], K), dtype=float)

P_te_raw    = np.zeros((X_te_all_df.shape[0],  K), dtype=float)
P_te_platt  = np.zeros((X_te_all_df.shape[0],  K), dtype=float)

# Per-class head bundles, keyed by str(label); packaged in Step 5b.
heads_mem = {}

# Accumulators for the importance / calibration exports.
xgb_shap_rows      = []
lr_contrib_rows    = []
platt_metrics_rows = []

# =============================================================================
# SECTION 4: TRAIN OvR BINARY HEADS (+ Platt on CAL)
# =============================================================================
print(f"\n[STEP 4] Training {K} binary OvR classifiers...\n")

TOP_N = 10                                # features shown in SHAP beeswarms / kept per class
base_order = ["NB", "XGB", "KNN", "MLP"]  # base-learner names, in StackingClassifier order


def _report_failure(context: str, err: Exception) -> None:
    """Warn about a non-fatal failure in a best-effort per-class step.

    Call only from inside an ``except`` block. When ``HARD_TRACEBACKS`` is set,
    the active exception's stack trace is also printed, so the root cause is not
    hidden behind a one-line warning.
    """
    warnings.warn(f"{context}: {err}")
    if HARD_TRACEBACKS:
        traceback.print_exc()


for celltype in class_names:
    k = class_to_idx[celltype]
    cls_safe = MLTraining.safe_name(celltype)
    print(f"▸ Processing {cls_safe} (class {k+1}/{K})")

    # 4.1 Load TRAIN barcodes for this class
    train_barcodes_df = pd.read_csv(
        f"{train_barcodes_path}/Hao/Consensus_annotation_{name_target_class.lower()}_final/Barcodes_training_class_{cls_safe}.csv",
        index_col=0
    )
    train_positive_barcodes = train_barcodes_df["Positive"].dropna().values
    train_negative_barcodes = train_barcodes_df["Negative"].dropna().values
    all_train_barcodes = np.concatenate([train_positive_barcodes, train_negative_barcodes])

    # Keep only listed barcodes present in TRAIN; label 1 = listed as positive.
    train_mask = Hao_data_Train_Sub.index.isin(all_train_barcodes)
    X_tr_df = Hao_data_Train_Sub.loc[train_mask]
    found_train_barcodes = X_tr_df.index.values
    y_tr = np.isin(found_train_barcodes, train_positive_barcodes).astype(int)

    if X_tr_df.empty or np.unique(y_tr).size < 2:
        # NOTE: a skipped class keeps all-zero columns in the P_* matrices and has
        # no entry in heads_mem, although it remains in class_names.
        print(f"  [SKIP] Empty or single-class train (pos={y_tr.sum()}, neg={len(y_tr)-y_tr.sum()})\n")
        continue

    # 4.1b TRAIN embedding (pos vs rest) + legend
    try:
        MLTraining.save_class_train_umap_pngs(
            celltype=str(celltype),
            cls_safe=cls_safe,
            barcodes=found_train_barcodes,
            y_bin=y_tr,
            custom_palette=custom_palette,
            out_dir_dev=fig_percls if EXPORT_DEV else None,
            out_dir_rel=release_single if EXPORT_RELEASE else None,
            adata_train=Hao_dataset_Train,
            train_df=Hao_data_Train,
            embedding_source=EMBEDDING_SOURCE,
            obsm_key=EMBEDDING_OBSM_KEY,
            obs_x=EMBEDDING_OBS_X,
            obs_y=EMBEDDING_OBS_Y,
            df_x=EMBEDDING_DF_X,
            df_y=EMBEDDING_DF_Y,
            neg_color="#A3A3A3",
            outline=(5, 0.05),
            debug=False,
        )
    except Exception as e:
        _report_failure(f"UMAP train plot failed for '{celltype}'", e)

    # 4.2 Load TEST barcodes for class-specific metrics (optional)
    test_barcodes_df = pd.read_csv(
        f"{test_barcodes_path}/Hao/Consensus_annotation_{name_target_class.lower()}_final/Barcodes_testing_class_{cls_safe}.csv",
        index_col=0
    )
    test_positive_barcodes = test_barcodes_df["Positive"].dropna().values
    test_negative_barcodes = test_barcodes_df["Negative"].dropna().values
    all_test_barcodes = np.concatenate([test_positive_barcodes, test_negative_barcodes])

    test_mask = Hao_data_Test_Sub.index.isin(all_test_barcodes)
    X_te_df = Hao_data_Test_Sub.loc[test_mask]
    found_test_barcodes = X_te_df.index.values
    y_te = np.isin(found_test_barcodes, test_positive_barcodes).astype(int)

    # Full TEST (filtered) for head probabilities / calibration plot eval
    X_te_all_local = X_te_all_df
    y_te_all = (Hao_data_Test.loc[X_te_all_df.index, "Celltype"].astype(str).values == str(celltype)).astype(int)

    # CAL split (filtered) for Platt fitting
    X_cal_df  = X_cal_all_df
    y_cal_bin = (Hao_data_Cal.loc[X_cal_all_df.index, "Celltype"].astype(str).values == str(celltype)).astype(int)

    # 4.3 Fit scaler on TRAIN; transform all splits
    scaler = StandardScaler(with_mean=True, with_std=True).fit(X_tr_df.values)

    def _sc(df: pd.DataFrame) -> pd.DataFrame:
        """Apply this class's TRAIN-fitted scaler, keeping index and panel column names."""
        return pd.DataFrame(scaler.transform(df.values), index=df.index, columns=cols_train)

    X_tr_sc_df      = _sc(X_tr_df)
    X_te_sc_df      = _sc(X_te_df)
    X_te_all_sc_df  = _sc(X_te_all_local)
    X_cal_sc_df     = _sc(X_cal_df)

    # 4.4 Train base learners
    NB_model  = MLTraining.train_NB (X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)
    XGB_model = MLTraining.train_XGB(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)
    KNN_model = MLTraining.train_KNN(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)
    MLP_model = MLTraining.train_MLP(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)

    # 4.5 Stacking RAW head
    stacker_raw = StackingClassifier(
        estimators=[("NB", NB_model), ("XGB", XGB_model), ("KNN", KNN_model), ("MLP", MLP_model)],
        final_estimator=LogisticRegression(max_iter=2000, class_weight="balanced", random_state=42),
        stack_method="predict_proba",
        cv=kf,
        n_jobs=-1,
    ).fit(X_tr_sc_df, y_tr)

    # 4.6 Platt calibration (fit on CAL only); needs both classes present in CAL
    pos_cal   = int(y_cal_bin.sum())
    n_cal_bin = int(len(y_cal_bin))
    has_both  = (0 < pos_cal < n_cal_bin)

    stacker_platt = None
    if has_both:
        stacker_platt = MLTraining.calibrate_prefit(stacker_raw, X_cal_sc_df, y_cal_bin, method="sigmoid")
    else:
        print("    [WARN] Skipped Platt calibration (single-class CAL)")

    # 4.6b Head probabilities on the full (filtered) TEST set. They are computed once
    # and reused by the calibration plot (4.7) and the OvR matrices (4.10), because
    # stacker predictions (KNN in particular) are expensive.
    p_test_raw   = stacker_raw.predict_proba(X_te_all_sc_df)[:, 1]
    p_test_platt = stacker_platt.predict_proba(X_te_all_sc_df)[:, 1] if stacker_platt is not None else None

    # 4.7 Platt evaluation curve on TEST (Ideal -> RAW -> Platt) + metrics row
    try:
        dev_platt = (fig_percls / f"{cls_safe}_Platt_calibration_evaluation_on_test.png") if EXPORT_DEV else None
        rel_platt = (release_single / f"{cls_safe}_Platt_calibration_evaluation_on_test.png") if EXPORT_RELEASE else None

        ll_raw, br_raw, ll_pl, br_pl, pl_avail = MLTraining.plot_platt_calibration_on_test(
            y_true_bin=y_te_all.astype(int),
            p_raw=p_test_raw,
            p_platt=p_test_platt,
            title=f"{name_target_class} – {celltype}: Platt calibration evaluation on TEST",
            out_png_dev=dev_platt,
            out_png_rel=rel_platt,
            n_bins=15,
        )

        platt_metrics_rows.append({
            "depth": name_target_class,
            "class_name": str(celltype),
            "n_test_samples": int(len(y_te_all)),
            "n_test_positive": int(y_te_all.sum()),
            "logloss_raw": ll_raw,
            "brier_raw": br_raw,
            "logloss_platt": ll_pl,
            "brier_platt": br_pl,
            "platt_available": bool(pl_avail),
        })

    except Exception as e:
        _report_failure(f"Platt calibration plot failed for class '{celltype}'", e)

    # 4.8 Save per-class head bundle + keep in-memory for package
    head_bundle = {
        "atlas": "Hao",
        "depth": name_target_class,
        "label": str(celltype),
        "panel_name": "TotalSeqD_Heme_Oncology_CAT399906",
        "columns": cols_train,
        "scaler": scaler,
        "model_raw": stacker_raw,
        "model_platt": stacker_platt,
    }
    heads_mem[str(celltype)] = head_bundle

    if EXPORT_DEV:
        joblib.dump(head_bundle, heads_dir / f"{cls_safe}.joblib")

    # 4.9 Optional per-head metrics logging (class-specific TEST subset)
    try:
        model_for_eval = stacker_platt if stacker_platt is not None else stacker_raw
        m = MLTraining.evaluate_classifier(model_for_eval, X_te_sc_df, y_te, plot_cm=False)
        m.update(celltype=str(celltype), used_platt=bool(stacker_platt is not None))
        metrics_log.append(m)
    except Exception as e:
        # Best-effort: must not abort training, but the failure is no longer silent.
        _report_failure(f"Per-head metrics evaluation failed for class '{celltype}'", e)

    # 4.10 OvR probability matrices (RAW + PLATT) for multiclass downstream
    P_cal_raw[:, k] = stacker_raw.predict_proba(X_cal_sc_df)[:, 1]
    P_te_raw[:,  k] = p_test_raw

    if stacker_platt is not None:
        P_cal_platt[:, k] = stacker_platt.predict_proba(X_cal_sc_df)[:, 1]
        P_te_platt[:,  k] = p_test_platt
    else:
        # No Platt model: fall back to RAW so the PLATT matrices stay fully populated.
        P_cal_platt[:, k] = P_cal_raw[:, k]
        P_te_platt[:,  k] = P_te_raw[:,  k]

    # 4.11 SHAP: mean_abs + corr on TEST; beeswarm TRAIN only
    if HAS_SHAP:
        # The scratch figure is kept in case the SHAP helpers draw on the current
        # figure. It was previously never closed, leaking one figure per class.
        scratch_fig = plt.figure(figsize=(6, 6))
        try:
            shap_sum_test = MLTraining.xgb_shap_mean_abs_and_corr(XGB_model, X_te_all_sc_df, class_index=1)
            shap_sum_test["depth"] = name_target_class
            shap_sum_test["class_name"] = str(celltype)
            shap_sum_test["dataset"] = "TEST"
            xgb_shap_rows.extend(shap_sum_test.to_dict(orient="records"))

            # Beeswarm on TRAIN only
            beeswarm_targets = []
            if EXPORT_DEV:
                beeswarm_targets.append(fig_percls / f"{cls_safe}_SHAP_beeswarm_TRAIN.png")
            if EXPORT_RELEASE:
                beeswarm_targets.append(release_single / f"{cls_safe}_SHAP_beeswarm_TRAIN.png")

            if beeswarm_targets:
                # Deterministic subsample of TRAIN (SHAP_TRAIN_SUBSAMPLE_MAX_N) to bound SHAP time/memory.
                X_shap_df = X_tr_sc_df
                if SHAP_TRAIN_SUBSAMPLE_MAX_N is not None and len(X_shap_df) > SHAP_TRAIN_SUBSAMPLE_MAX_N:
                    X_shap_df = X_shap_df.sample(n=SHAP_TRAIN_SUBSAMPLE_MAX_N, random_state=42)
                # SHAP values are computed once and reused for every export target.
                sv, _ = MLTraining.xgb_shap_values(XGB_model, X_shap_df, class_index=1)
                for outp in beeswarm_targets:
                    MLTraining.plot_xgb_shap_beeswarm(
                        sv, X_shap_df,
                        title=f"{name_target_class} – {celltype}: SHAP importance",
                        max_display=TOP_N,
                        out_path=outp,
                        figsize=(6, 7),
                    )

        except Exception as e:
            _report_failure(f"SHAP failed for class '{celltype}'", e)
        finally:
            plt.close(scratch_fig)

    # 4.12 LR meta-learner contributions (unchanged)
    try:
        contrib = _lr_baselearner_contributions(stacker_raw, X_te_all_sc_df, base_order=base_order)
        row = {
            "depth": name_target_class,
            "class_name": str(celltype),
            "dataset": "TEST",
            "n_meta_features": contrib["n_meta_features"],
            "per_estimator_meta_cols": contrib["per_estimator_meta_cols"],
        }
        for b in base_order:
            row[f"{b}_mean_abs_contribution"] = contrib["per_base"].get(b, {}).get("mean_abs_contribution", 0.0)
            row[f"{b}_coef_l1"]               = contrib["per_base"].get(b, {}).get("coef_l1", 0.0)
            row[f"{b}_n_meta_cols"]           = contrib["per_base"].get(b, {}).get("n_cols", 0)
        lr_contrib_rows.append(row)
    except Exception as e:
        _report_failure(f"LR contribution extraction failed for class '{celltype}'", e)

    print("")

# =============================================================================
# EXPORT: Per-class LogLoss & Brier (pre vs post Platt) on TEST
# =============================================================================
print("\n[EXPORT] Per-class calibration metrics (RAW vs Platt on TEST)...")

# One row per trained class (collected in Step 4.7); None skips a disabled target.
_ = MLTraining.export_platt_metrics_csv(
    platt_metrics_rows,
    out_dev=metrics_dir if EXPORT_DEV else None,
    out_rel=release_metrics if EXPORT_RELEASE else None,
    filename="Single_classes_metrics_pre_and_post_platt_calibration.csv",
)

# =============================================================================
# SECTION 5: MULTICLASS TEMPERATURE SCALING (fit on CAL using PLATT matrix)
# =============================================================================
print("\n[STEP 5] Multiclass Temperature Scaling on CAL (using Platt OvR probabilities)...")

def _check_probs(P: np.ndarray, name: str):
    if np.isnan(P).any() or np.isinf(P).any():
        raise ValueError(f"{name} contains NaN/Inf")
    if (P < 0).any() or (P > 1).any():
        raise ValueError(f"{name} contains values outside [0,1]")

_check_probs(P_cal_platt, "P_cal_platt")
_check_probs(P_te_platt,  "P_te_platt")

# One temperature is fitted on the CAL split's Platt OvR matrix and applied to TEST.
ts_cal = TemperatureScaling()
ts_cal.fit(P_cal_platt, y_cal_multiclass)
P_te_cal = np.asarray(ts_cal.transform(P_te_platt))
# Remember what the scaler actually returned: P_te_cal may be overwritten by the
# fallback below, and the warning must report the original shape.
ts_out_shape = P_te_cal.shape

if P_te_cal.ndim == 1:
    P_te_cal = P_te_cal.reshape(-1, 1)

if P_te_cal.shape[1] == 1 and K == 2:
    # Binary case returned as a single positive-class column: rebuild both columns.
    P_te_cal = np.hstack([1.0 - P_te_cal, P_te_cal])
elif P_te_cal.shape[1] != K:
    # Unexpected width: fall back to row-normalised Platt OvR probabilities
    # (all-zero rows are left as zeros rather than divided by 0).
    row_sums = P_te_platt.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0.0] = 1.0
    P_te_cal = P_te_platt / row_sums
    print(f"  [WARN] TemperatureScaling returned shape {ts_out_shape}; fell back to sum-normalized OvR probs")

if EXPORT_DEV:
    joblib.dump(ts_cal, models_root / "temp_scaler.joblib")
    pd.Series(class_names, name="class_name").to_csv(models_root / "class_names.csv", index=False)

# =============================================================================
# SECTION 5b: SAVE DEPLOYABLE PACKAGE(S)
# =============================================================================
print("\n[STEP 5b] Saving deployable package(s)...")

# Everything needed at inference time: class order (= P_* column order), the
# per-class head bundles (scaler + RAW/Platt stackers) and the temperature scaler.
package = {
    "atlas": "Hao",
    "depth": name_target_class,
    "panel_name": "TotalSeqD_Heme_Oncology_CAT399906",
    "class_names": class_names,
    "heads": heads_mem,
    "temp_scaler": ts_cal,
}

if EXPORT_DEV:
    joblib.dump(package, models_root / "package.joblib")

if EXPORT_RELEASE:
    joblib.dump(package, release_models / "Multiclass_models.joblib")

# =============================================================================
# SECTION 5c: EXPORT IMPORTANCES (Top10 per class)
# =============================================================================
print("\n[STEP 5c] Exporting importances (Top 10 per class; SHAP mean_abs + corr + LR)...")

_IMP_GROUP_KEYS = ["depth", "class_name"]
_SHAP_EXPORT_COLS = [
    "depth", "class_name", "dataset", "feature",
    "mean_abs_shap", "corr_feature_value_vs_shap", "rank_within_class",
]

# --- XGB SHAP: keep the TOP_N features per (depth, class) by mean |SHAP|.
if not xgb_shap_rows:
    print("  [INFO] No SHAP rows collected (or SHAP not installed).")
else:
    shap_df = pd.DataFrame(xgb_shap_rows)
    shap_df = shap_df.sort_values(
        _IMP_GROUP_KEYS + ["mean_abs_shap"], ascending=[True, True, False]
    )
    shap_df = shap_df.groupby(_IMP_GROUP_KEYS, as_index=False).head(TOP_N)
    # Rank 1 = most important within its class; ties broken by row order.
    shap_df["rank_within_class"] = (
        shap_df.groupby(_IMP_GROUP_KEYS)["mean_abs_shap"]
        .rank(ascending=False, method="first")
        .astype(int)
    )
    shap_df = shap_df[_SHAP_EXPORT_COLS]
    if EXPORT_DEV:
        shap_df.to_csv(dev_importances / "SHAP_XGB_Feature_importances.csv", index=False)
    if EXPORT_RELEASE:
        shap_df.to_csv(release_imps / "SHAP_XGB_Feature_importances.csv", index=False)

# --- LR meta-learner: per-base-learner contribution rows, exported as collected.
if not lr_contrib_rows:
    print("  [INFO] No LR contribution rows collected.")
else:
    lr_df = pd.DataFrame(lr_contrib_rows)
    if EXPORT_DEV:
        lr_df.to_csv(dev_importances / "LR_MetaLearner_BaseLearner_contributions.csv", index=False)
    if EXPORT_RELEASE:
        lr_df.to_csv(release_imps / "LR_MetaLearner_BaseLearner_contributions.csv", index=False)

# =============================================================================
# SECTION 6: SAVE PROBABILITIES
# =============================================================================
print("\n[STEP 6] Saving probability outputs...")


def _prob_frame(P, prefix):
    """Wrap an (n_test, K) probability matrix as a DataFrame indexed by `test_index`.

    Columns are named ``f"{prefix}_{class_name}"`` in `class_names` order.
    """
    return pd.DataFrame(P, index=test_index, columns=[f"{prefix}_{c}" for c in class_names])


if EXPORT_DEV:
    # DEV: RAW | Platt | calibrated probabilities side by side, plus the true
    # label and argmax predictions (as class indices and as class names).
    probs_raw_df   = _prob_frame(P_te_raw,   "raw")
    probs_platt_df = _prob_frame(P_te_platt, "platt")
    probs_cal_df   = _prob_frame(P_te_cal,   "cal")

    pred_raw_idx = P_te_raw.argmax(axis=1)
    pred_cal_idx = P_te_cal.argmax(axis=1)
    probs_dev = pd.concat([probs_raw_df, probs_platt_df, probs_cal_df], axis=1).assign(
        true_label=Hao_data_Test.loc[test_index, "Celltype"].values,
        pred_raw=pred_raw_idx,
        pred_cal=pred_cal_idx,
        pred_raw_name=[class_names[i] for i in pred_raw_idx],
        pred_cal_name=[class_names[i] for i in pred_cal_idx],
    )
    probs_dev.to_csv(probs_dir / "probabilities_before_after_TEST.csv", index=True)

if EXPORT_RELEASE:
    # RELEASE: calibrated probabilities only, plus label, argmax prediction and
    # the winning probability.
    probs_cal_df = _prob_frame(P_te_cal, "cal")
    release_pred_idx = P_te_cal.argmax(axis=1)
    probs_release = probs_cal_df.copy()
    probs_release["true_label"]    = Hao_data_Test.loc[test_index, "Celltype"].values
    probs_release["pred_cal"]      = release_pred_idx
    probs_release["pred_cal_name"] = [class_names[i] for i in release_pred_idx]
    probs_release["max_cal_prob"]  = probs_cal_df.max(axis=1).values
    probs_release.to_csv(release_probs / "Multiclass_models_probabilities_on_test.csv", index=True)

# =============================================================================
# SECTION 7: MULTICLASS EVALUATION (TEST) — using CAL probabilities
# =============================================================================
print("\n[STEP 7] Multiclass evaluation (TEST; using CAL probs)...\n")

# Pin the label set to 0..K-1 so every entry of `class_names` gets a row, even
# a class absent from both y_true and y_pred on TEST (e.g. a TRAIN class with no
# TEST cells that is never predicted). Without `labels`, sklearn infers fewer
# labels than `target_names` and raises ValueError. zero_division=0 reports the
# undefined precision/recall of such classes as 0 without a warning.
mc_labels = list(range(K))

y_pred_cal = P_te_cal.argmax(axis=1)
report_txt = classification_report(
    y_test_multiclass, y_pred_cal,
    labels=mc_labels, target_names=class_names, digits=3, zero_division=0,
)
print("Multiclass Classification Report (TEST):")
print(report_txt)

cm_mc = confusion_matrix(y_test_multiclass, y_pred_cal, labels=mc_labels)
print("\nConfusion Matrix (rows=true, cols=pred):")
print(cm_mc)

report_df = pd.DataFrame(
    classification_report(
        y_test_multiclass, y_pred_cal,
        labels=mc_labels, target_names=class_names, output_dict=True, zero_division=0,
    )
).T

cm_mc_df = pd.DataFrame(cm_mc, index=pd.Index(class_names, name="true"), columns=pd.Index(class_names, name="pred"))

if EXPORT_DEV:
    report_df.to_csv(metrics_dir / "multiclass_classification_report_TEST.csv")
    cm_mc_df.to_csv(metrics_dir / "multiclass_confusion_matrix_TEST.csv")

if EXPORT_RELEASE:
    report_df.to_csv(release_metrics / "Multiclass_models_metrics_on_test.csv")
    cm_mc_df.to_csv(release_metrics / "Multiclass_models_confusion_matrix_on_test.csv")

# =============================================================================
# SECTION 8: FIGURES (MULTICLASS CM + PER-CLASS CONF & ROC)
# =============================================================================
print("\n[STEP 8] Saving plots...")

def _save_multiclass_cm_png(out_path: Path):
    """Render the module-level multiclass confusion matrix `cm_mc` to *out_path*.

    The display is drawn onto an explicitly created Axes. Without ``ax=...``,
    ``ConfusionMatrixDisplay.plot`` opens a *new* figure, so the requested
    figsize was ignored and the figure that was actually saved was never closed.
    The figure is closed even if saving fails.
    """
    fig, ax = plt.subplots(figsize=(7, 6))
    try:
        disp = ConfusionMatrixDisplay(confusion_matrix=cm_mc, display_labels=class_names)
        disp.plot(ax=ax, values_format="d", cmap="Blues", colorbar=False)
        ax.set_title(f"{name_target_class} – Multiclass Confusion Matrix (on TEST)")
        fig.tight_layout()
        fig.savefig(out_path, dpi=300)
    finally:
        plt.close(fig)

# Multiclass confusion-matrix figure for each enabled export target.
if EXPORT_DEV:
    _save_multiclass_cm_png(out_path=fig_root / "multiclass_confusion_matrix_TEST.png")
if EXPORT_RELEASE:
    _save_multiclass_cm_png(out_path=release_figs / "Multiclass_models_confusion_matrix_on_test.png")

# RAW argmax predictions (reused by every class below) and the accumulator for
# per-class one-vs-rest metric rows.
y_pred_raw = np.argmax(P_te_raw, axis=1)
per_class_rows = []

def _metrics_from_cm(cm2x2):
    """Derive binary counts and precision/recall/F1 from a 2x2 confusion matrix.

    *cm2x2* uses the layout of sklearn's ``confusion_matrix(..., labels=[0, 1])``,
    i.e. ``[[TN, FP], [FN, TP]]``. Ratios whose denominator is zero are 0.0.

    Returns a dict with keys TP, FP, TN, FN, support (= TP + FN), precision,
    recall, f1.
    """
    def _safe_ratio(num, den):
        return num / den if den > 0 else 0.0

    tn, fp, fn, tp = cm2x2.ravel()
    precision = _safe_ratio(tp, tp + fp)
    recall = _safe_ratio(tp, tp + fn)
    f1 = _safe_ratio(2 * precision * recall, precision + recall)
    return {
        "TP": int(tp),
        "FP": int(fp),
        "TN": int(tn),
        "FN": int(fn),
        "support": int(tp + fn),
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }

def _save_cm_fig(cm2x2, cls_label, title, out_dev: Path | None, out_rel: Path | None):
    """Save a 2x2 one-vs-rest confusion matrix ("Other" vs *cls_label*) as a PNG.

    The figure is written to every path among (*out_dev*, *out_rel*) that is
    not None. When both are None, nothing is rendered.

    The display is drawn onto an explicitly created Axes. Without ``ax=...``,
    ``ConfusionMatrixDisplay.plot`` opens a second figure, so the requested
    figsize was ignored and the figure actually saved was never closed (one
    leaked figure per call). The figure is closed even if saving fails.
    """
    targets = [p for p in (out_dev, out_rel) if p is not None]
    if not targets:
        return

    fig, ax = plt.subplots(figsize=(5.5, 5.0))
    try:
        ConfusionMatrixDisplay(confusion_matrix=cm2x2, display_labels=["Other", cls_label]).plot(
            ax=ax, values_format="d", cmap="Blues", colorbar=False
        )
        ax.set_title(title)
        fig.tight_layout()
        for target in targets:
            fig.savefig(target, dpi=300)
    finally:
        plt.close(fig)

def _save_roc(y_true, y_score, title, out_dev: Path | None, out_rel: Path | None):
    """Plot a ROC curve with the chance diagonal and return its AUC.

    The figure title gets ``AUC=<value>`` appended. The PNG is written to each of
    *out_dev* / *out_rel* that is not None, and the figure is then closed.
    """
    fpr, tpr, _thresholds = roc_curve(y_true, y_score)
    roc_auc = auc(fpr, tpr)

    fig, ax = plt.subplots(figsize=(6.0, 5.5))
    ax.plot([0, 1], [0, 1], linestyle="--", linewidth=1, color="gray")  # chance line
    ax.plot(fpr, tpr)
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title(f"{title} AUC={roc_auc:.3f}")
    fig.tight_layout()
    for destination in (out_dev, out_rel):
        if destination is not None:
            fig.savefig(destination, dpi=300)
    plt.close(fig)
    return roc_auc

def _per_class_fig_paths(cls_safe, stem):
    """Return (dev_path, release_path) for a per-class TEST figure; None where export is off."""
    fname = f"{cls_safe}_{stem}_on_test.png"
    return (
        (fig_percls / fname) if EXPORT_DEV else None,
        (release_single / fname) if EXPORT_RELEASE else None,
    )


_RAW_TAG = "RAW; pre-Platt & Temp"
_CAL_TAG = "CAL; Platt & Temp"

# For each class: one-vs-rest argmax confusion matrices and ROC curves for the
# RAW and the calibrated (CAL) probability matrices, plus one metrics row each.
for k, cls in enumerate(class_names):
    cls_safe = MLTraining.safe_name(cls)
    y_true_bin = (y_test_multiclass == k).astype(int)
    prefix = f"{name_target_class} – {cls}"

    cm_raw = confusion_matrix(y_true_bin, (y_pred_raw == k).astype(int), labels=[0, 1])
    cm_cal = confusion_matrix(y_true_bin, (y_pred_cal == k).astype(int), labels=[0, 1])

    _save_cm_fig(cm_raw, cls, f"{prefix}: Confusion Matrix ({_RAW_TAG})",
                 *_per_class_fig_paths(cls_safe, "RAW_confusion_matrix"))
    _save_cm_fig(cm_cal, cls, f"{prefix}: Confusion Matrix ({_CAL_TAG})",
                 *_per_class_fig_paths(cls_safe, "CAL_confusion_matrix"))
    auc_raw = _save_roc(y_true_bin, P_te_raw[:, k], f"{prefix}: ROC ({_RAW_TAG})",
                        *_per_class_fig_paths(cls_safe, "RAW_ROC"))
    auc_cal = _save_roc(y_true_bin, P_te_cal[:, k], f"{prefix}: ROC ({_CAL_TAG})",
                        *_per_class_fig_paths(cls_safe, "CAL_ROC"))

    for tag, cm_bin, auc_value in (("RAW", cm_raw, auc_raw), ("CAL", cm_cal, auc_cal)):
        metric_row = _metrics_from_cm(cm_bin)
        metric_row.update(model=tag, class_name=cls, auc=auc_value)
        per_class_rows.append(metric_row)

# =============================================================================
# SECTION 9: SAVE METRICS TABLES
# =============================================================================
print("\n[STEP 9] Saving metrics tables...")

# One row per (class, model), sorted by class, then model (CAL before RAW).
_PER_CLASS_COLUMNS = [
    "class_name", "model", "TP", "FP", "TN", "FN",
    "support", "precision", "recall", "f1", "auc",
]
per_class_df = (
    pd.DataFrame(per_class_rows)
    .loc[:, _PER_CLASS_COLUMNS]
    .sort_values(["class_name", "model"])
)

if EXPORT_DEV:
    per_class_df.to_csv(metrics_dir / "per_class_argmax_metrics_TEST_included.csv", index=False)

if EXPORT_RELEASE:
    out_single = release_metrics / "Single_classes_metrics_and_confusion_matrix_on_test.csv"
    per_class_df.to_csv(out_single, index=False)

if EXPORT_DEV:
    # Append this run's per-head metrics to the cumulative DEV log.
    metrics_df = pd.DataFrame.from_records(metrics_log)
    MLTraining.append_metrics_csv(metrics_df, csv_path=dev_root / "stacker_metrics.csv")

print("\n✅ SIMPLIFIED PIPELINE COMPLETE. Exports saved according to EXPORT_DEV / EXPORT_RELEASE.\n")


#### Detailed annotation

In [None]:
# -*- coding: utf-8 -*-
# =============================================================================
# MODEL TRAINING PIPELINE (LEAN MAIN SCRIPT)
#   - RAW vs PLATT vs TEMP-SCALED
#   - DEV/RELEASE exports
#   - Importances: XGB SHAP mean_abs + corr (Top10) + LR meta-learner contributions
#   - Platt calibration plots (Ideal -> RAW -> Platt on top) with TEST LogLoss/Brier in legend
#   - Per-class pre/post Platt metrics exported to CSV
#   - Per-class TRAIN UMAP (pos vs rest) + legend PNG
#
# PATCHES ADDED (to address “plots missing / skipped” symptoms):
#   (A) Optional DEBUG_DIAGNOSTICS: prints output paths + CAL class balance + confirms file writes.
#   (B) Hard traceback on failures (instead of silent warnings) to surface root cause.
#   (C) SHAP beeswarm robustification: optional subsample of TRAIN to avoid memory/time failures.
#   (D) Optional SAFE_SINGLE_THREAD: mitigates fork/thread/numba/TBB instability during SHAP/plotting.
#   (E) Explicit existence checks after savefig (so “saved but not where expected” is obvious).
# =============================================================================

# =============================================================================
# SECTION 0: IMPORTS + CONFIG
# =============================================================================

from pathlib import Path
import joblib
import warnings
import traceback
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import StackingClassifier
from sklearn.metrics import (
    classification_report, confusion_matrix, ConfusionMatrixDisplay,
    roc_curve, auc,
)

import MLTraining  # uses MLTraining.py helpers

# -----------------------------------------------------------------------------
# Palettes
# -----------------------------------------------------------------------------
# Class label -> hex colour, one palette per annotation depth. The palette for
# the active depth is looked up through PALETTE_BY_DEPTH and passed to the
# per-class TRAIN UMAP plotting helper as `custom_palette`.
# NOTE(review): keys presumably must match the Celltype label strings of the
# corresponding consensus annotation exactly — confirm when adding classes.

PALETTE_BROAD = {"Immature": "#0079ea", "Mature": "#AF3434"}

PALETTE_SIMPLIFIED = {
    "HSPC":      "#0079ea",
    "Erythroid": "#c11212",
    "pDC":       "#62E6B8",
    "Monocyte":  "#D27CE3",
    "Myeloid":   "#8D43CD",
    "CD4_T":     "#C99546",
    "CD8_T":     "#6B3317",
    "B":         "#68D827",
    "cDC":       "#16D2E3",
    "Other_T":   "#EDB416",
    "NK":        "#FBEF0D",
}

# NOTE: "Monocyte" and "CD14 Mono" map to the same colour (#D27CE3).
PALETTE_DETAILED = {
    "HSC_MPP":            "#0079ea",
    "LMPP":               "#17BECF",
    "GMP":                "#C5E4FF",
    "Myeloid progenitor": "#AEC7E8",
    "Monocyte":           "#D27CE3",
    "CD14 Mono":          "#D27CE3",
    "CD16 Mono":          "#8D43CD",
    "Erythroblast":       "#F30A1A",
    "ErP":                "#D1235A",
    "MEP":                "#E364B0",
    "CD4 T Naive":        "#C99546",
    "CD4 T Memory":       "#C1AF93",
    "CD8 T Naive":        "#4D382E",
    "CD8 T Memory":       "#6B3317",
    "Other_T":            "#EDB416",
    "Treg":               "#6E6C37",
    "B Naive":            "#1C511D",
    "B Memory":           "#68D827",
    "Pro-B":              "#66BB6A",
    "Pre-B":              "#2DBD67",
    "Immature B":         "#91FF7B",
    "Plasma":             "#9DC012",
    "cDC1":               "#76A7CB",
    "cDC2":               "#16D2E3",
    "pDC":                "#69FFCB",
    "NK CD56 bright":     "#F3AC1F",
    "NK CD56 dim":        "#FBEF0D",
}

# Keyed by the values accepted for `name_target_class`.
PALETTE_BY_DEPTH = {
    "Broad": PALETTE_BROAD,
    "Simplified": PALETTE_SIMPLIFIED,
    "Detailed": PALETTE_DETAILED,
}

# -----------------------------------------------------------------------------
# OPTIONAL: SHAP dependency
# -----------------------------------------------------------------------------
# HAS_SHAP gates every SHAP computation/plot below. "shap is not installed" is
# an expected, quiet skip. Any other import failure (a broken dependency, an
# incompatible build, ...) is printed, because an earlier cell silences
# `warnings`. Otherwise that failure would only surface as missing SHAP plots.
try:
    import shap  # noqa: F401
    HAS_SHAP = True
except Exception as exc:
    HAS_SHAP = False
    if not (isinstance(exc, ImportError) and exc.name == "shap"):
        print(f"[WARN] shap import failed ({exc!r}); SHAP outputs will be skipped.")

# -----------------------------------------------------------------------------
# EXPORT SWITCHES
# -----------------------------------------------------------------------------
EXPORT_RELEASE = True
EXPORT_DEV     = False

# -----------------------------------------------------------------------------
# CONFIG
# -----------------------------------------------------------------------------
name_target_class = "Detailed"  # "Broad" | "Simplified" | "Detailed"
# Class labels to leave out of training/CAL/TEST, compared as str(label).
# Explicit set(): the literal `{}` is an empty dict, not an empty set.
EXCLUDE_CLASSES = set()

if name_target_class not in PALETTE_BY_DEPTH:
    # Paths and the consensus field are still derived from this name, but the
    # palette lookup below would silently fall back to an empty palette.
    print(f"[WARN] Unknown name_target_class={name_target_class!r}; expected one of {list(PALETTE_BY_DEPTH)}. Using an empty palette.")

custom_palette = PALETTE_BY_DEPTH.get(name_target_class, {})
kf          = MLTraining.CV
num_cores   = -1
metrics_log = []

# -----------------------------------------------------------------------------
# DIAGNOSTICS / ROBUSTIFICATION SWITCHES (PATCH)
# -----------------------------------------------------------------------------
DEBUG_DIAGNOSTICS = True
HARD_TRACEBACKS   = True   # if True: prints stack traces when plot/SHAP fails
SHAP_TRAIN_SUBSAMPLE_MAX_N = 5000  # set None to disable subsampling
SAFE_SINGLE_THREAD = False  # set True if you see Numba/TBB fork/thread warnings

# -----------------------------------------------------------------------------
# EMBEDDING CONFIG (for Class_Train_data.png)
# -----------------------------------------------------------------------------
EMBEDDING_SOURCE = "adata_obsm"   # "adata_obsm" | "adata_obs" | "train_df"
EMBEDDING_OBSM_KEY = "X_wnn.umap"
EMBEDDING_OBS_X = "UMAP_1"
EMBEDDING_OBS_Y = "UMAP_2"
EMBEDDING_DF_X = "UMAP_1"
EMBEDDING_DF_Y = "UMAP_2"

# -----------------------------------------------------------------------------
# ROOTS
# -----------------------------------------------------------------------------
hao_root = Path(models_output)

dev_root     = hao_root / "Dev"
models_root  = dev_root / name_target_class / "Models"  / name_target_class
reports_root = dev_root / name_target_class / "Reports" / name_target_class
fig_root     = dev_root / name_target_class / "Figures" / name_target_class

heads_dir       = models_root / "heads"
metrics_dir     = reports_root / "metrics"
probs_dir       = reports_root / "probabilities"
fig_percls      = fig_root / "per_class"
dev_importances = reports_root / "Importances"

release_root    = hao_root / "Release"
release_models  = release_root / name_target_class / "Models"
release_reports = release_root / name_target_class / "Reports"
release_metrics = release_reports / "Metrics"
release_probs   = release_reports / "Probabilities"
release_imps    = release_reports / "Importances"
release_figs    = release_root / name_target_class / "Figures"
release_single  = release_figs / "Single_classes"

if EXPORT_DEV:
    for p in (models_root, heads_dir, reports_root, metrics_dir, probs_dir, fig_root, fig_percls, dev_importances):
        p.mkdir(parents=True, exist_ok=True)

if EXPORT_RELEASE:
    for p in (release_models, release_reports, release_metrics, release_probs, release_imps, release_figs, release_single):
        p.mkdir(parents=True, exist_ok=True)
    print(f"[INFO] RELEASE Root:    {release_root}")
    print(f"[INFO] RELEASE Models:  {release_models}")
    print(f"[INFO] RELEASE Reports: {release_reports}")
    print(f"[INFO] RELEASE Figures: {release_figs}")

if DEBUG_DIAGNOSTICS:
    print(f"[DEBUG] HAS_SHAP={HAS_SHAP} EXPORT_RELEASE={EXPORT_RELEASE} EXPORT_DEV={EXPORT_DEV}")
    print(f"[DEBUG] release_single={release_single}")
    print(f"[DEBUG] release_imps={release_imps}")
    print(f"[DEBUG] SAFE_SINGLE_THREAD={SAFE_SINGLE_THREAD} SHAP_SUBSAMPLE_MAX_N={SHAP_TRAIN_SUBSAMPLE_MAX_N}")

# =============================================================================
# SECTION 1: ATTACH CELL-TYPE LABELS
# =============================================================================
print("\n[STEP 1] Attaching cell-type labels from AnnData.obs...")

# e.g. "Consensus_annotation_detailed_final" for name_target_class="Detailed"
consensus_field = f"Consensus_annotation_{name_target_class.lower()}_final"

# Each feature frame is paired with the AnnData object of the same split.
Hao_data_Train, Hao_data_Test, Hao_data_Cal = (
    MLTraining.attach_celltype(frame, adata, consensus_field)
    for frame, adata in (
        (Hao_data_Train, Hao_dataset_Train),
        (Hao_data_Test, Hao_dataset_Test),
        (Hao_data_Cal, Hao_dataset_Cal),
    )
)

print(f"  ✓ Attached '{consensus_field}' to Train/Test/Cal splits")

# =============================================================================
# SECTION 2: ALIGN DATA COLUMNS TO REFERENCE PANEL
# =============================================================================
print("\n[STEP 2] Aligning data columns to reference panel (exact names preserved)...")

# Reference panel as exact string names, plus a lookup from normalized key back
# to the exact panel name.
panel = pd.Index([str(marker) for marker in TotalSeqD_Heme_Oncology_CAT399906])
panel_keys = MLTraining.norm_feats(panel)
norm_to_panel = {key: marker for key, marker in zip(panel_keys, panel)}

# Two panel names collapsing onto one normalized key would make the mapping ambiguous.
if len(norm_to_panel) != len(panel):
    raise ValueError("Panel contains names that collide after normalization. Adjust MLTraining.norm_feats rules.")

def rename_data_to_panel(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df* whose feature columns use exact panel names.

    Feature columns (everything except "cell_barcode"/"Celltype") are matched to
    the panel through ``MLTraining.norm_feats`` and the module-level
    `norm_to_panel` lookup. Columns with no panel match keep their name. When
    several columns map to the same panel name, the first one wins and the
    others are dropped (with a warning). Prints a match summary.
    """
    out = df.copy()
    meta_cols = [c for c in ("cell_barcode", "Celltype") if c in out.columns]
    feature_cols = pd.Index([c for c in out.columns if c not in meta_cols])

    panel_names = [norm_to_panel.get(key) for key in MLTraining.norm_feats(feature_cols)]
    candidates = {old: new for old, new in zip(feature_cols, panel_names) if new is not None}

    claimed, safe_map, duplicates = set(), {}, []
    for old, new in candidates.items():
        if new in claimed:
            duplicates.append(old)
            continue
        claimed.add(new)
        safe_map[old] = new

    if duplicates:
        print(f"  [WARN] Dropping {len(duplicates)} duplicated-mapped columns (sample: {duplicates[:5]})")
        out = out.drop(columns=duplicates, errors="ignore")

    out = out.rename(columns=safe_map)
    print(f"  ✓ Matched {len(safe_map)}/{len(feature_cols)} data columns to panel")
    return out

def panel_intersection(df: pd.DataFrame) -> pd.DataFrame:
    """Keep only panel features of *df* (in panel order), followed by metadata columns.

    Metadata columns are "cell_barcode"/"Celltype" when present. Raises
    ValueError if no feature column matches the module-level `panel`.
    """
    meta_cols = [c for c in ("cell_barcode", "Celltype") if c in df.columns]
    features = pd.Index([c for c in df.columns if c not in meta_cols])
    shared = panel.intersection(features, sort=False)
    if shared.empty:
        raise ValueError("Panel/Data intersection is empty after renaming. Check mapping rules.")
    return df.reindex(columns=[*shared, *meta_cols])

# Rename to exact panel names, then restrict every split to the panel features.
Hao_data_Train, Hao_data_Test, Hao_data_Cal = (
    panel_intersection(rename_data_to_panel(split))
    for split in (Hao_data_Train, Hao_data_Test, Hao_data_Cal)
)
print("  ✓ Data columns now aligned to panel (panel order preserved)")

# =============================================================================
# SECTION 3: PREPARE FEATURES & LABELS (WITH CAL/TEST ROW FILTERING)
# =============================================================================
print("\n[STEP 3] Extracting features and labels...")

# Keep CAL labels aside before the feature frames drop the metadata columns.
Hao_data_Cal_lbl = Hao_data_Cal[["Celltype"]].copy()

drop_cols_train = [c for c in ["cell_barcode", "Celltype"] if c in Hao_data_Train.columns]
drop_cols_test  = [c for c in ["cell_barcode", "Celltype"] if c in Hao_data_Test.columns]
drop_cols_cal   = [c for c in ["cell_barcode", "Celltype"] if c in Hao_data_Cal.columns]

Hao_data_Train_Sub = Hao_data_Train.drop(columns=drop_cols_train, errors="ignore")
Hao_data_Test_Sub  = Hao_data_Test.drop(columns=drop_cols_test,  errors="ignore")
Hao_data_Cal_Sub   = Hao_data_Cal.drop(columns=drop_cols_cal,    errors="ignore")

# All splits must expose identical feature columns in identical order, because
# one TRAIN-fitted scaler/model per class is applied to each of them.
cols_train = list(Hao_data_Train_Sub.columns)
if list(Hao_data_Test_Sub.columns) != cols_train or list(Hao_data_Cal_Sub.columns) != cols_train:
    raise ValueError("Train/Cal/Test feature columns differ after panel intersection!")

MLTraining.check_finite(Hao_data_Train_Sub, "TRAIN")
MLTraining.check_finite(Hao_data_Test_Sub,  "TEST")
MLTraining.check_finite(Hao_data_Cal_Sub,   "CAL")

print(f"  ✓ Using {len(cols_train)} panel-intersected features (exact panel names)")
print(f"    Sample: {cols_train[:5]}...")

# classes learned from TRAIN, excluding user-specified (matched on str(label))
all_classes = sorted(pd.Series(Hao_data_Train["Celltype"]).dropna().unique())
class_names = [c for c in all_classes if str(c) not in EXCLUDE_CLASSES]

# Report exclusions with the SAME str()-based rule that built class_names; the
# previous raw-value set intersection could disagree with it for non-str labels.
excluded_present = [c for c in all_classes if str(c) in EXCLUDE_CLASSES]
if excluded_present:
    print(f"  [INFO] Excluding {len(excluded_present)} classes: {excluded_present}")

# Entries that match nothing are most likely typos; surface them.
excluded_unknown = sorted(set(map(str, EXCLUDE_CLASSES)) - set(map(str, all_classes)))
if excluded_unknown:
    print(f"  [WARN] EXCLUDE_CLASSES entries not found in TRAIN labels: {excluded_unknown}")

K            = len(class_names)
if K == 0:
    raise ValueError("No classes left to train after applying EXCLUDE_CLASSES.")
class_to_idx = {c: i for i, c in enumerate(class_names)}
print(f"  ✓ Found {K} classes after exclusions")

# ---- critical: filter CAL/TEST rows to those classes ----
keep_set = set(map(str, class_names))

cal_keep_mask  = Hao_data_Cal_lbl["Celltype"].astype(str).isin(keep_set)
test_keep_mask = Hao_data_Test["Celltype"].astype(str).isin(keep_set)

n_cal_drop  = int((~cal_keep_mask).sum())
n_test_drop = int((~test_keep_mask).sum())

if n_cal_drop > 0:
    dropped = sorted(Hao_data_Cal_lbl.loc[~cal_keep_mask, "Celltype"].astype(str).unique().tolist())
    print(f"  [INFO] Dropping {n_cal_drop} CAL rows with excluded/unknown labels: {dropped}")

if n_test_drop > 0:
    dropped = sorted(Hao_data_Test.loc[~test_keep_mask, "Celltype"].astype(str).unique().tolist())
    print(f"  [INFO] Dropping {n_test_drop} TEST rows with excluded/unknown labels: {dropped}")

# filtered label frames
Hao_data_Cal_lbl_f  = Hao_data_Cal_lbl.loc[cal_keep_mask].copy()
Hao_data_Test_lbl_f = Hao_data_Test.loc[test_keep_mask, ["Celltype"]].copy()

# filtered feature frames (must align by index)
X_cal_all_df = Hao_data_Cal_Sub.loc[Hao_data_Cal_lbl_f.index].copy()
X_te_all_df  = Hao_data_Test_Sub.loc[Hao_data_Test_lbl_f.index].copy()
test_index   = X_te_all_df.index

# map filtered labels to class indices (0..K-1, in class_names order)
s_cal = Hao_data_Cal_lbl_f["Celltype"].map(class_to_idx)
if s_cal.isna().any():
    missing = Hao_data_Cal_lbl_f.loc[s_cal.isna(), "Celltype"].unique()
    raise ValueError(f"Still-unmapped labels in CAL after filtering: {missing}")
y_cal_multiclass = s_cal.to_numpy(dtype=np.int64)

s_te = Hao_data_Test_lbl_f["Celltype"].map(class_to_idx)
if s_te.isna().any():
    missing = Hao_data_Test_lbl_f.loc[s_te.isna(), "Celltype"].unique()
    raise ValueError(f"Still-unmapped labels in TEST after filtering: {missing}")
y_test_multiclass = s_te.to_numpy(dtype=np.int64)

# probability matrices sized to filtered CAL/TEST; column k is filled by class k's head
P_cal_raw   = np.zeros((X_cal_all_df.shape[0], K), dtype=float)
P_cal_platt = np.zeros((X_cal_all_df.shape[0], K), dtype=float)

P_te_raw    = np.zeros((X_te_all_df.shape[0],  K), dtype=float)
P_te_platt  = np.zeros((X_te_all_df.shape[0],  K), dtype=float)

heads_mem = {}

xgb_shap_rows      = []
lr_contrib_rows    = []
platt_metrics_rows = []

# =============================================================================
# SECTION 4: TRAIN OvR BINARY HEADS (+ Platt on CAL)
# =============================================================================
print(f"\n[STEP 4] Training {K} binary OvR classifiers...\n")

TOP_N = 10
base_order = ["NB", "XGB", "KNN", "MLP"]

# SAFE_SINGLE_THREAD forces serial base learners and a serial stacker. With the
# default (False), these equal the values used before the switch was honoured.
_head_num_cores = 1 if SAFE_SINGLE_THREAD else num_cores
_stack_n_jobs   = 1 if SAFE_SINGLE_THREAD else -1


def _report_failure(what: str, exc: BaseException) -> None:
    """Surface a non-fatal per-class failure.

    An earlier notebook cell calls ``warnings.filterwarnings('ignore')``, so
    ``warnings.warn`` alone is silent. The message is printed instead, and with
    HARD_TRACEBACKS the full stack trace is printed too.
    """
    print(f"    [WARN] {what}: {exc!r}")
    if HARD_TRACEBACKS:
        print("".join(traceback.format_exception(type(exc), exc, exc.__traceback__)))


def _diag_written(path) -> None:
    """With DEBUG_DIAGNOSTICS, report whether an expected output file now exists."""
    if DEBUG_DIAGNOSTICS and path is not None:
        status = "written" if Path(path).exists() else "MISSING"
        print(f"    [DEBUG] {status}: {path}")


def _shap_beeswarm_rows(X: pd.DataFrame) -> pd.DataFrame:
    """Return the TRAIN rows used for the SHAP beeswarm.

    Rows are capped at SHAP_TRAIN_SUBSAMPLE_MAX_N via a deterministic sample
    (random_state=42). None disables subsampling.
    """
    max_n = SHAP_TRAIN_SUBSAMPLE_MAX_N
    if max_n is None or len(X) <= max_n:
        return X
    return X.sample(n=int(max_n), random_state=42)


for celltype in class_names:
    k = class_to_idx[celltype]
    cls_safe = MLTraining.safe_name(celltype)
    print(f"▸ Processing {cls_safe} (class {k+1}/{K})")

    # 4.1 Load TRAIN barcodes for this class ("Positive" / "Negative" columns)
    train_barcodes_df = pd.read_csv(
        f"{train_barcodes_path}/Hao/Consensus_annotation_{name_target_class.lower()}_final/Barcodes_training_class_{cls_safe}.csv",
        index_col=0
    )
    train_positive_barcodes = train_barcodes_df["Positive"].dropna().values
    train_negative_barcodes = train_barcodes_df["Negative"].dropna().values
    all_train_barcodes = np.concatenate([train_positive_barcodes, train_negative_barcodes])

    train_mask = Hao_data_Train_Sub.index.isin(all_train_barcodes)
    X_tr_df = Hao_data_Train_Sub.loc[train_mask]
    found_train_barcodes = X_tr_df.index.values
    y_tr = np.isin(found_train_barcodes, train_positive_barcodes).astype(int)

    if X_tr_df.empty or np.unique(y_tr).size < 2:
        # NOTE: column k of the P_* matrices stays all-zero for a skipped class.
        print(f"  [SKIP] Empty or single-class train (pos={y_tr.sum()}, neg={len(y_tr)-y_tr.sum()})\n")
        continue

    # 4.1b TRAIN embedding (pos vs rest) + legend
    try:
        MLTraining.save_class_train_umap_pngs(
            celltype=str(celltype),
            cls_safe=cls_safe,
            barcodes=found_train_barcodes,
            y_bin=y_tr,
            custom_palette=custom_palette,
            out_dir_dev=fig_percls if EXPORT_DEV else None,
            out_dir_rel=release_single if EXPORT_RELEASE else None,
            adata_train=Hao_dataset_Train,
            train_df=Hao_data_Train,
            embedding_source=EMBEDDING_SOURCE,
            obsm_key=EMBEDDING_OBSM_KEY,
            obs_x=EMBEDDING_OBS_X,
            obs_y=EMBEDDING_OBS_Y,
            df_x=EMBEDDING_DF_X,
            df_y=EMBEDDING_DF_Y,
            neg_color="#A3A3A3",
            outline=(5, 0.05),
            debug=False,
        )
    except Exception as e:
        _report_failure(f"UMAP train plot failed for '{celltype}'", e)

    # 4.2 Load TEST barcodes for class-specific metrics (optional)
    test_barcodes_df = pd.read_csv(
        f"{test_barcodes_path}/Hao/Consensus_annotation_{name_target_class.lower()}_final/Barcodes_testing_class_{cls_safe}.csv",
        index_col=0
    )
    test_positive_barcodes = test_barcodes_df["Positive"].dropna().values
    test_negative_barcodes = test_barcodes_df["Negative"].dropna().values
    all_test_barcodes = np.concatenate([test_positive_barcodes, test_negative_barcodes])

    test_mask = Hao_data_Test_Sub.index.isin(all_test_barcodes)
    X_te_df = Hao_data_Test_Sub.loc[test_mask]
    found_test_barcodes = X_te_df.index.values
    y_te = np.isin(found_test_barcodes, test_positive_barcodes).astype(int)

    # Full TEST (filtered) for head probabilities / calibration plot eval
    X_te_all_local = X_te_all_df
    y_te_all = (Hao_data_Test.loc[X_te_all_df.index, "Celltype"].astype(str).values == str(celltype)).astype(int)

    # CAL split (filtered) for Platt fitting
    X_cal_df  = X_cal_all_df
    y_cal_bin = (Hao_data_Cal.loc[X_cal_all_df.index, "Celltype"].astype(str).values == str(celltype)).astype(int)

    # 4.3 Fit scaler on TRAIN; transform all splits
    scaler = StandardScaler(with_mean=True, with_std=True).fit(X_tr_df.values)

    def _sc(df: pd.DataFrame) -> pd.DataFrame:
        return pd.DataFrame(scaler.transform(df.values), index=df.index, columns=cols_train)

    X_tr_sc_df      = _sc(X_tr_df)
    # StandardScaler.transform rejects 0-row input. An empty class-specific TEST
    # subset (it only feeds step 4.9) must not abort the whole run.
    X_te_sc_df      = _sc(X_te_df) if not X_te_df.empty else None
    X_te_all_sc_df  = _sc(X_te_all_local)
    X_cal_sc_df     = _sc(X_cal_df)

    # 4.4 Train base learners
    NB_model  = MLTraining.train_NB (X_tr_sc_df, y_tr, cv=kf, num_cores=_head_num_cores, name_target_subclass=cls_safe)
    XGB_model = MLTraining.train_XGB(X_tr_sc_df, y_tr, cv=kf, num_cores=_head_num_cores, name_target_subclass=cls_safe)
    KNN_model = MLTraining.train_KNN(X_tr_sc_df, y_tr, cv=kf, num_cores=_head_num_cores, name_target_subclass=cls_safe)
    MLP_model = MLTraining.train_MLP(X_tr_sc_df, y_tr, cv=kf, num_cores=_head_num_cores, name_target_subclass=cls_safe)

    # 4.5 Stacking RAW head
    stacker_raw = StackingClassifier(
        estimators=[("NB", NB_model), ("XGB", XGB_model), ("KNN", KNN_model), ("MLP", MLP_model)],
        final_estimator=LogisticRegression(max_iter=2000, class_weight="balanced", random_state=42),
        stack_method="predict_proba",
        cv=kf,
        n_jobs=_stack_n_jobs,
    ).fit(X_tr_sc_df, y_tr)

    # 4.6 Platt calibration (fit on CAL only; needs both classes present)
    pos_cal   = int(y_cal_bin.sum())
    n_cal_bin = int(len(y_cal_bin))
    has_both  = (0 < pos_cal < n_cal_bin)

    stacker_platt = None
    if has_both:
        stacker_platt = MLTraining.calibrate_prefit(stacker_raw, X_cal_sc_df, y_cal_bin, method="sigmoid")
    else:
        print("    [WARN] Skipped Platt calibration (single-class CAL)")

    # 4.7 Platt evaluation curve on TEST (Ideal -> RAW -> Platt) + metrics row
    try:
        p_test_raw   = stacker_raw.predict_proba(X_te_all_sc_df)[:, 1]
        p_test_platt = stacker_platt.predict_proba(X_te_all_sc_df)[:, 1] if stacker_platt is not None else None

        dev_platt = (fig_percls / f"{cls_safe}_Platt_calibration_evaluation_on_test.png") if EXPORT_DEV else None
        rel_platt = (release_single / f"{cls_safe}_Platt_calibration_evaluation_on_test.png") if EXPORT_RELEASE else None

        ll_raw, br_raw, ll_pl, br_pl, pl_avail = MLTraining.plot_platt_calibration_on_test(
            y_true_bin=y_te_all.astype(int),
            p_raw=p_test_raw,
            p_platt=p_test_platt,
            title=f"{name_target_class} – {celltype}: Platt calibration evaluation on TEST",
            out_png_dev=dev_platt,
            out_png_rel=rel_platt,
            n_bins=15,
        )
        _diag_written(dev_platt)
        _diag_written(rel_platt)

        platt_metrics_rows.append({
            "depth": name_target_class,
            "class_name": str(celltype),
            "n_test_samples": int(len(y_te_all)),
            "n_test_positive": int(y_te_all.sum()),
            "logloss_raw": ll_raw,
            "brier_raw": br_raw,
            "logloss_platt": ll_pl,
            "brier_platt": br_pl,
            "platt_available": bool(pl_avail),
        })

    except Exception as e:
        _report_failure(f"Platt calibration plot failed for class '{celltype}'", e)

    # 4.8 Save per-class head bundle + keep in-memory for package
    head_bundle = {
        "atlas": "Hao",
        "depth": name_target_class,
        "label": str(celltype),
        "panel_name": "TotalSeqD_Heme_Oncology_CAT399906",
        "columns": cols_train,
        "scaler": scaler,
        "model_raw": stacker_raw,
        "model_platt": stacker_platt,
    }
    heads_mem[str(celltype)] = head_bundle

    if EXPORT_DEV:
        joblib.dump(head_bundle, heads_dir / f"{cls_safe}.joblib")

    # 4.9 Optional per-head metrics logging (class-specific TEST subset)
    if X_te_sc_df is None:
        print("    [INFO] No class-specific TEST rows found; skipping per-head metrics")
    else:
        try:
            model_for_eval = stacker_platt if stacker_platt is not None else stacker_raw
            m = MLTraining.evaluate_classifier(model_for_eval, X_te_sc_df, y_te, plot_cm=False)
            m.update(celltype=str(celltype), used_platt=bool(stacker_platt is not None))
            metrics_log.append(m)
        except Exception as e:
            # Previously swallowed with `pass`; still non-fatal, but now visible.
            _report_failure(f"Per-head evaluation failed for class '{celltype}'", e)

    # 4.10 OvR probability matrices (RAW + PLATT) for multiclass downstream
    P_cal_raw[:, k] = stacker_raw.predict_proba(X_cal_sc_df)[:, 1]
    P_te_raw[:,  k] = stacker_raw.predict_proba(X_te_all_sc_df)[:, 1]

    if stacker_platt is not None:
        P_cal_platt[:, k] = stacker_platt.predict_proba(X_cal_sc_df)[:, 1]
        P_te_platt[:,  k] = stacker_platt.predict_proba(X_te_all_sc_df)[:, 1]
    else:
        P_cal_platt[:, k] = P_cal_raw[:, k]
        P_te_platt[:,  k] = P_te_raw[:,  k]

    # 4.11 SHAP: mean_abs + corr on TEST; beeswarm TRAIN only
    if HAS_SHAP:
        shap_canvas = None
        try:
            # Kept in case the beeswarm helper draws on the current figure; it is
            # now always closed in `finally` (previously one figure leaked per class).
            shap_canvas = plt.figure(figsize=(6, 6))
            shap_sum_test = MLTraining.xgb_shap_mean_abs_and_corr(XGB_model, X_te_all_sc_df, class_index=1)
            shap_sum_test["depth"] = name_target_class
            shap_sum_test["class_name"] = str(celltype)
            shap_sum_test["dataset"] = "TEST"
            xgb_shap_rows.extend(shap_sum_test.to_dict(orient="records"))

            # Beeswarm on TRAIN only. SHAP values are computed once (on an
            # optionally subsampled TRAIN set) and reused for every enabled
            # export target, instead of being recomputed per target.
            beeswarm_paths = []
            if EXPORT_DEV:
                beeswarm_paths.append(fig_percls / f"{cls_safe}_SHAP_beeswarm_TRAIN.png")
            if EXPORT_RELEASE:
                beeswarm_paths.append(release_single / f"{cls_safe}_SHAP_beeswarm_TRAIN.png")

            if beeswarm_paths:
                X_bee = _shap_beeswarm_rows(X_tr_sc_df)
                sv, _ = MLTraining.xgb_shap_values(XGB_model, X_bee, class_index=1)
                for outp in beeswarm_paths:
                    MLTraining.plot_xgb_shap_beeswarm(
                        sv, X_bee,
                        title=f"{name_target_class} – {celltype}: SHAP importance",
                        max_display=TOP_N,
                        out_path=outp,
                        figsize=(6, 7),
                    )
                    _diag_written(outp)

        except Exception as e:
            _report_failure(f"SHAP failed for class '{celltype}'", e)
        finally:
            if shap_canvas is not None:
                plt.close(shap_canvas)

    # 4.12 LR meta-learner contributions (unchanged)
    try:
        contrib = _lr_baselearner_contributions(stacker_raw, X_te_all_sc_df, base_order=base_order)
        row = {
            "depth": name_target_class,
            "class_name": str(celltype),
            "dataset": "TEST",
            "n_meta_features": contrib["n_meta_features"],
            "per_estimator_meta_cols": contrib["per_estimator_meta_cols"],
        }
        for b in base_order:
            row[f"{b}_mean_abs_contribution"] = contrib["per_base"].get(b, {}).get("mean_abs_contribution", 0.0)
            row[f"{b}_coef_l1"]               = contrib["per_base"].get(b, {}).get("coef_l1", 0.0)
            row[f"{b}_n_meta_cols"]           = contrib["per_base"].get(b, {}).get("n_cols", 0)
        lr_contrib_rows.append(row)
    except Exception as e:
        _report_failure(f"LR contribution extraction failed for class '{celltype}'", e)

    print("")

# =============================================================================
# EXPORT: Per-class LogLoss & Brier (pre vs post Platt) on TEST
# =============================================================================
print("\n[EXPORT] Per-class calibration metrics (RAW vs Platt on TEST)...")

# One row per trained class (collected in STEP 4.7), written to DEV and/or RELEASE.
_platt_csv_dev = metrics_dir if EXPORT_DEV else None
_platt_csv_rel = release_metrics if EXPORT_RELEASE else None
_ = MLTraining.export_platt_metrics_csv(
    platt_metrics_rows,
    out_dev=_platt_csv_dev,
    out_rel=_platt_csv_rel,
    filename="Single_classes_metrics_pre_and_post_platt_calibration.csv",
)

# =============================================================================
# SECTION 5: MULTICLASS TEMPERATURE SCALING (fit on CAL using PLATT matrix)
# =============================================================================
print("\n[STEP 5] Multiclass Temperature Scaling on CAL (using Platt OvR probabilities)...")

def _check_probs(P: np.ndarray, name: str):
    if np.isnan(P).any() or np.isinf(P).any():
        raise ValueError(f"{name} contains NaN/Inf")
    if (P < 0).any() or (P > 1).any():
        raise ValueError(f"{name} contains values outside [0,1]")

# Validate the OvR Platt probability matrices before fitting/applying temperature scaling.
_check_probs(P_cal_platt, "P_cal_platt")
_check_probs(P_te_platt,  "P_te_platt")

# Fit one multiclass temperature on the CAL split (Platt OvR probabilities + integer
# class labels), then apply it to the TEST Platt matrix.
ts_cal = TemperatureScaling()
ts_cal.fit(P_cal_platt, y_cal_multiclass)
P_te_cal = np.asarray(ts_cal.transform(P_te_platt))
if P_te_cal.ndim == 1:
    P_te_cal = P_te_cal.reshape(-1, 1)

# Capture the scaler's output shape *before* any fallback overwrites P_te_cal, so the
# warning below reports what TemperatureScaling actually returned (previously it
# printed the shape of the fallback matrix instead).
ts_output_shape = P_te_cal.shape

if ts_output_shape[1] == 1 and K == 2:
    # Binary problem returned as a single column: expand to two columns.
    # NOTE(review): assumes the single column is P(class 1) — confirm against netcal.
    P_te_cal = np.hstack([1.0 - P_te_cal, P_te_cal])
elif ts_output_shape[1] != K:
    # Unexpected width: fall back to row-normalised OvR Platt probabilities
    # (all-zero rows are divided by 1 and therefore stay zero).
    row_sums = P_te_platt.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0.0] = 1.0
    P_te_cal = P_te_platt / row_sums
    print(f"  [WARN] TemperatureScaling returned shape {ts_output_shape}; fell back to sum-normalized OvR probs")

if EXPORT_DEV:
    joblib.dump(ts_cal, models_root / "temp_scaler.joblib")
    pd.Series(class_names, name="class_name").to_csv(models_root / "class_names.csv", index=False)

# =============================================================================
# SECTION 5b: SAVE DEPLOYABLE PACKAGE(S)
# =============================================================================
print("\n[STEP 5b] Saving deployable package(s)...")

# Single joblib-serialisable bundle: per-class head bundles (scaler + RAW/Platt
# stackers, keyed by class name), the ordered class list, and the fitted
# multiclass temperature scaler.
package = {
    "atlas": "Hao",
    "depth": name_target_class,
    "panel_name": "TotalSeqD_Heme_Oncology_CAT399906",
    "class_names": class_names,
    "heads": heads_mem,
    "temp_scaler": ts_cal,
}

# The same object is written under different file names for DEV and RELEASE.
if EXPORT_DEV:
    joblib.dump(package, models_root / "package.joblib")

if EXPORT_RELEASE:
    joblib.dump(package, release_models / "Multiclass_models.joblib")

# =============================================================================
# SECTION 5c: EXPORT IMPORTANCES (Top10 per class)
# =============================================================================
print("\n[STEP 5c] Exporting importances (Top 10 per class; SHAP mean_abs + corr + LR)...")

if len(xgb_shap_rows) > 0:
    shap_df = pd.DataFrame(xgb_shap_rows)
    # Keep the TOP_N features with the largest mean |SHAP| per (depth, class).
    # TOP_N must already be defined in the notebook at this point.
    shap_df = (
        shap_df.sort_values(["depth", "class_name", "mean_abs_shap"], ascending=[True, True, False])
               .groupby(["depth", "class_name"], as_index=False)
               .head(TOP_N)
    )
    # 1-based rank within each class (ties broken by row order via method="first").
    shap_df["rank_within_class"] = (
        shap_df.groupby(["depth", "class_name"])["mean_abs_shap"]
               .rank(ascending=False, method="first")
               .astype(int)
    )
    # Fixed export column order.
    shap_df = shap_df[
        ["depth", "class_name", "dataset", "feature", "mean_abs_shap", "corr_feature_value_vs_shap", "rank_within_class"]
    ]
    if EXPORT_DEV:
        shap_df.to_csv(dev_importances / "SHAP_XGB_Feature_importances.csv", index=False)
    if EXPORT_RELEASE:
        shap_df.to_csv(release_imps / "SHAP_XGB_Feature_importances.csv", index=False)
else:
    print("  [INFO] No SHAP rows collected (or SHAP not installed).")

# Meta-learner contributions are exported as collected (one row per class).
if len(lr_contrib_rows) > 0:
    lr_df = pd.DataFrame(lr_contrib_rows)
    if EXPORT_DEV:
        lr_df.to_csv(dev_importances / "LR_MetaLearner_BaseLearner_contributions.csv", index=False)
    if EXPORT_RELEASE:
        lr_df.to_csv(release_imps / "LR_MetaLearner_BaseLearner_contributions.csv", index=False)
else:
    print("  [INFO] No LR contribution rows collected.")

# =============================================================================
# SECTION 6: SAVE PROBABILITIES
# =============================================================================
print("\n[STEP 6] Saving probability outputs...")

# DEV: RAW, Platt and temperature-calibrated TEST probabilities side by side
# (columns prefixed raw_/platt_/cal_), plus true label and argmax predictions
# as both class index and class name.
if EXPORT_DEV:
    probs_raw_df   = pd.DataFrame(P_te_raw,   index=test_index, columns=[f"raw_{c}"   for c in class_names])
    probs_platt_df = pd.DataFrame(P_te_platt, index=test_index, columns=[f"platt_{c}" for c in class_names])
    probs_cal_df   = pd.DataFrame(P_te_cal,   index=test_index, columns=[f"cal_{c}"   for c in class_names])

    probs_dev = pd.concat([probs_raw_df, probs_platt_df, probs_cal_df], axis=1)
    probs_dev["true_label"] = Hao_data_Test.loc[test_index, "Celltype"].values
    probs_dev["pred_raw"]   = P_te_raw.argmax(axis=1)
    probs_dev["pred_cal"]   = P_te_cal.argmax(axis=1)
    probs_dev["pred_raw_name"] = [class_names[i] for i in probs_dev["pred_raw"].values]
    probs_dev["pred_cal_name"] = [class_names[i] for i in probs_dev["pred_cal"].values]
    probs_dev.to_csv(probs_dir / "probabilities_before_after_TEST.csv", index=True)

# RELEASE: calibrated probabilities only, plus true label, argmax prediction and
# the highest calibrated probability per cell.
if EXPORT_RELEASE:
    probs_cal_df = pd.DataFrame(P_te_cal, index=test_index, columns=[f"cal_{c}" for c in class_names])
    probs_release = probs_cal_df.copy()
    probs_release["true_label"]    = Hao_data_Test.loc[test_index, "Celltype"].values
    probs_release["pred_cal"]      = P_te_cal.argmax(axis=1)
    probs_release["pred_cal_name"] = [class_names[i] for i in probs_release["pred_cal"].values]
    probs_release["max_cal_prob"]  = probs_cal_df.max(axis=1).values
    probs_release.to_csv(release_probs / "Multiclass_models_probabilities_on_test.csv", index=True)

# =============================================================================
# SECTION 7: MULTICLASS EVALUATION (TEST) — using CAL probabilities
# =============================================================================
print("\n[STEP 7] Multiclass evaluation (TEST; using CAL probs)...\n")

y_pred_cal = P_te_cal.argmax(axis=1)

# Pin the label set to all K classes so the report always lines up with
# `class_names`. Without `labels`, sklearn infers labels from y_true/y_pred and
# raises a target_names size mismatch when a class is absent from both on TEST.
# This matches the `labels=` already used for the confusion matrix below.
eval_labels = list(range(K))

report_txt = classification_report(
    y_test_multiclass, y_pred_cal, labels=eval_labels, target_names=class_names, digits=3
)
print("Multiclass Classification Report (TEST):")
print(report_txt)

cm_mc = confusion_matrix(y_test_multiclass, y_pred_cal, labels=eval_labels)
print("\nConfusion Matrix (rows=true, cols=pred):")
print(cm_mc)

# Same report as a table (one row per class plus the aggregate rows).
report_df = pd.DataFrame(
    classification_report(
        y_test_multiclass, y_pred_cal, labels=eval_labels, target_names=class_names, output_dict=True
    )
).T

cm_mc_df = pd.DataFrame(cm_mc, index=pd.Index(class_names, name="true"), columns=pd.Index(class_names, name="pred"))

if EXPORT_DEV:
    report_df.to_csv(metrics_dir / "multiclass_classification_report_TEST.csv")
    cm_mc_df.to_csv(metrics_dir / "multiclass_confusion_matrix_TEST.csv")

if EXPORT_RELEASE:
    report_df.to_csv(release_metrics / "Multiclass_models_metrics_on_test.csv")
    cm_mc_df.to_csv(release_metrics / "Multiclass_models_confusion_matrix_on_test.csv")

# =============================================================================
# SECTION 8: FIGURES (MULTICLASS CM + PER-CLASS CONF & ROC)
# =============================================================================
print("\n[STEP 8] Saving plots...")

def _save_multiclass_cm_png(out_path: Path):
    """Render the TEST multiclass confusion matrix and save it as a PNG.

    Reads the module-level ``cm_mc``, ``class_names`` and ``name_target_class``.

    Args:
        out_path: destination PNG path (written at 300 dpi).
    """
    # Draw on an explicitly created Axes. ConfusionMatrixDisplay.plot() without
    # ``ax`` opens its own default-size figure. That ignored the requested figsize
    # and left the drawn figure open (one leaked figure per call).
    fig, ax = plt.subplots(figsize=(7, 6))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm_mc, display_labels=class_names)
    disp.plot(ax=ax, values_format="d", cmap="Blues", colorbar=False)
    ax.set_title(f"{name_target_class} – Multiclass Confusion Matrix (on TEST)")
    fig.tight_layout()
    fig.savefig(out_path, dpi=300)
    plt.close(fig)

# Multiclass confusion-matrix figure for each enabled export target.
if EXPORT_DEV:
    _save_multiclass_cm_png(fig_root / "multiclass_confusion_matrix_TEST.png")
if EXPORT_RELEASE:
    _save_multiclass_cm_png(release_figs / "Multiclass_models_confusion_matrix_on_test.png")

# Collector for per-class RAW/CAL metric rows, filled in the loop below.
per_class_rows = []
# Argmax predictions from the uncalibrated OvR matrix (RAW baseline).
y_pred_raw = P_te_raw.argmax(axis=1)

def _metrics_from_cm(cm2x2):
    tn, fp, fn, tp = cm2x2.ravel()
    support = tp + fn
    prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    rec  = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1   = (2 * prec * rec) / (prec + rec) if (prec + rec) > 0 else 0.0
    return dict(TP=int(tp), FP=int(fp), TN=int(tn), FN=int(fn),
                support=int(support), precision=prec, recall=rec, f1=f1)

def _save_cm_fig(cm2x2, cls_label, title, out_dev: Path | None, out_rel: Path | None):
    """Render a one-vs-rest 2x2 confusion matrix and save it to each non-None path.

    Args:
        cm2x2: 2x2 confusion matrix for labels [0, 1], displayed as "Other" / ``cls_label``.
        cls_label: display name of the positive class.
        title: axes title.
        out_dev: DEV output PNG path, or None to skip.
        out_rel: RELEASE output PNG path, or None to skip.
    """
    # Draw on an explicitly created Axes. ConfusionMatrixDisplay.plot() without
    # ``ax`` opens its own default-size figure. That ignored the requested figsize
    # and was never closed, leaking one figure per call; with several calls per
    # class this builds up quickly.
    fig, ax = plt.subplots(figsize=(5.5, 5.0))
    ConfusionMatrixDisplay(confusion_matrix=cm2x2, display_labels=["Other", cls_label]).plot(
        ax=ax, values_format="d", cmap="Blues", colorbar=False
    )
    ax.set_title(title)
    fig.tight_layout()
    for out_path in (out_dev, out_rel):
        if out_path is not None:
            fig.savefig(out_path, dpi=300)
    plt.close(fig)

def _save_roc(y_true, y_score, title, out_dev: Path | None, out_rel: Path | None):
    """Plot a binary ROC curve (with chance diagonal), save it, and return its AUC.

    Args:
        y_true: binary ground-truth labels.
        y_score: scores/probabilities for the positive class.
        title: title prefix; the AUC is appended as ``" AUC=<value>"``.
        out_dev: DEV output PNG path, or None to skip.
        out_rel: RELEASE output PNG path, or None to skip.

    Returns:
        The area under the ROC curve.
    """
    false_pos_rate, true_pos_rate, _thresholds = roc_curve(y_true, y_score)
    roc_auc = auc(false_pos_rate, true_pos_rate)

    fig, ax = plt.subplots(figsize=(6.0, 5.5))
    # Chance-level reference line first, then the model curve.
    ax.plot([0, 1], [0, 1], linestyle="--", linewidth=1, color="gray")
    ax.plot(false_pos_rate, true_pos_rate)
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title(f"{title} AUC={roc_auc:.3f}")
    fig.tight_layout()

    for target in (out_dev, out_rel):
        if target is not None:
            fig.savefig(target, dpi=300)
    plt.close(fig)
    return roc_auc

# Per class k: binarise truth and argmax predictions (class k vs rest), then save
# RAW and CAL confusion matrices and ROC curves, and record one metrics row per
# (class, model).
for k, cls in enumerate(class_names):
    cls_safe = MLTraining.safe_name(cls)
    y_true_bin = (y_test_multiclass == k).astype(int)

    # ROC scores: column k of the RAW and temperature-calibrated probability matrices.
    score_raw = P_te_raw[:, k]
    score_cal = P_te_cal[:, k]

    # Confusion matrices use the multiclass argmax decision, not a 0.5 threshold.
    y_pred_raw_bin = (y_pred_raw == k).astype(int)
    y_pred_cal_bin = (y_pred_cal == k).astype(int)

    cm_raw = confusion_matrix(y_true_bin, y_pred_raw_bin, labels=[0, 1])
    cm_cal = confusion_matrix(y_true_bin, y_pred_cal_bin, labels=[0, 1])

    dev_out = (fig_percls / f"{cls_safe}_RAW_confusion_matrix_on_test.png") if EXPORT_DEV else None
    rel_out = (release_single / f"{cls_safe}_RAW_confusion_matrix_on_test.png") if EXPORT_RELEASE else None
    _save_cm_fig(cm_raw, cls, f"{name_target_class} – {cls}: Confusion Matrix (RAW; pre-Platt & Temp)", dev_out, rel_out)

    dev_out = (fig_percls / f"{cls_safe}_CAL_confusion_matrix_on_test.png") if EXPORT_DEV else None
    rel_out = (release_single / f"{cls_safe}_CAL_confusion_matrix_on_test.png") if EXPORT_RELEASE else None
    _save_cm_fig(cm_cal, cls, f"{name_target_class} – {cls}: Confusion Matrix (CAL; Platt & Temp)", dev_out, rel_out)

    dev_out = (fig_percls / f"{cls_safe}_RAW_ROC_on_test.png") if EXPORT_DEV else None
    rel_out = (release_single / f"{cls_safe}_RAW_ROC_on_test.png") if EXPORT_RELEASE else None
    auc_raw = _save_roc(y_true_bin, score_raw, f"{name_target_class} – {cls}: ROC (RAW; pre-Platt & Temp)", dev_out, rel_out)

    dev_out = (fig_percls / f"{cls_safe}_CAL_ROC_on_test.png") if EXPORT_DEV else None
    rel_out = (release_single / f"{cls_safe}_CAL_ROC_on_test.png") if EXPORT_RELEASE else None
    auc_cal = _save_roc(y_true_bin, score_cal, f"{name_target_class} – {cls}: ROC (CAL; Platt & Temp)", dev_out, rel_out)

    # One row each for RAW and CAL: confusion counts + precision/recall/F1 + AUC.
    m_raw = _metrics_from_cm(cm_raw); m_raw.update(model="RAW", class_name=cls, auc=auc_raw); per_class_rows.append(m_raw)
    m_cal = _metrics_from_cm(cm_cal); m_cal.update(model="CAL", class_name=cls, auc=auc_cal); per_class_rows.append(m_cal)

# =============================================================================
# SECTION 9: SAVE METRICS TABLES
# =============================================================================
print("\n[STEP 9] Saving metrics tables...")

# Fixed column order; rows grouped by class with RAW/CAL adjacent.
per_class_df = pd.DataFrame(per_class_rows)[
    ["class_name", "model", "TP", "FP", "TN", "FN", "support", "precision", "recall", "f1", "auc"]
].sort_values(["class_name", "model"])

if EXPORT_DEV:
    per_class_df.to_csv(metrics_dir / "per_class_argmax_metrics_TEST_included.csv", index=False)

if EXPORT_RELEASE:
    out_single = release_metrics / "Single_classes_metrics_and_confusion_matrix_on_test.csv"
    per_class_df.to_csv(out_single, index=False)

# Per-head metrics collected during training go to a DEV-only CSV through
# MLTraining.append_metrics_csv (append semantics per the helper's name — confirm in MLTraining).
if EXPORT_DEV:
    metrics_df = pd.DataFrame.from_records(metrics_log)
    MLTraining.append_metrics_csv(metrics_df, csv_path=dev_root / "stacker_metrics.csv")

print("\n✅ DETAILED PIPELINE COMPLETE. Exports saved according to EXPORT_DEV / EXPORT_RELEASE.\n")


## Zhang Models

In [None]:
# Create the Zhang output tree under `data_path`. The calls are made in the same order
# as before, parent first: Zhang, Zhang/Dev, Zhang/Release, then the two Models
# subfolders. exist_ok=True makes re-running the cell harmless.
models_output = data_path + "/Zhang"
for _zhang_subdir in ("", "/Dev", "/Release", "/Dev/Models", "/Release/Models"):
    os.makedirs(models_output + _zhang_subdir, exist_ok=True)

### ML Training

In [None]:
# Empty container for Zhang-atlas models; it is not read or written anywhere in this visible section.
Zhang_Models = {}

#### Broad annotation

In [None]:
# -*- coding: utf-8 -*-
# =============================================================================
# MODEL TRAINING PIPELINE (LEAN MAIN SCRIPT)
#   - RAW vs PLATT vs TEMP-SCALED
#   - DEV/RELEASE exports
#   - Importances: XGB SHAP mean_abs + corr (Top10) + LR meta-learner contributions
#   - Platt calibration plots (Ideal -> RAW -> Platt on top) with TEST LogLoss/Brier in legend
#   - Per-class pre/post Platt metrics exported to CSV
#   - Per-class TRAIN UMAP (pos vs rest) + legend PNG
#
# NOTE:
#   Many helper functions are now provided by MLTraining.py.
#   This script should primarily orchestrate: data prep -> model training loop -> exports.
# =============================================================================

from pathlib import Path
import joblib
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import StackingClassifier
from sklearn.calibration import CalibratedClassifierCV
from sklearn.metrics import (
    classification_report, confusion_matrix, ConfusionMatrixDisplay,
    roc_curve, auc,
)

import MLTraining  # uses MLTraining.py helpers

# -----------------------------------------------------------------------------
# Palettes
# -----------------------------------------------------------------------------
# Hex colours keyed by cell-type label, one dict per annotation depth.

# Depth "Broad": two classes.
PALETTE_BROAD = {
    'Immature': "#0079ea", 
    'Mature': "#AF3434"
}

# Depth "Simplified".
PALETTE_SIMPLIFIED = {
    "HSPC":      "#0079ea",
    "Erythroid": "#c11212",
    "pDC":       "#62E6B8",
    "Monocyte":  "#D27CE3",
    "Myeloid":   "#8D43CD",
    "CD4_T":     "#C99546",
    "CD8_T":     "#6B3317",
    "B":         "#68D827",
    "cDC":       "#16D2E3",
    "Other_T":   "#EDB416",
    "NK":        "#FBEF0D",
}

# Depth "Detailed".
# NOTE(review): 'Monocyte' and 'CD14 Mono' share "#D27CE3" — confirm that is intended.
PALETTE_DETAILED = {
    'HSC_MPP':            '#0079ea',
    'LMPP':               "#17BECF",
    'GMP':                "#C5E4FF",
    'Myeloid progenitor': "#AEC7E8",
    'Monocyte':           "#D27CE3",
    'CD14 Mono':         "#D27CE3",
    'CD16 Mono':         "#8D43CD",
    'Erythroblast':      "#F30A1A",
    'ErP':               "#D1235A",
    'MEP':               "#E364B0",
    'CD4 T Naive':       "#C99546",
    'CD4 T Memory':      "#C1AF93",
    'CD8 T Naive':       "#4D382E",
    'CD8 T Memory':      "#6B3317",
    'Other_T':           "#EDB416",
    'Treg':              "#6E6C37",
    'B Naive':          '#1C511D',
    'B Memory':         "#68D827",
    'Pro-B':            "#66BB6A",
    'Pre-B':            "#2DBD67",
    'Immature B':      "#91FF7B",
    'Plasma':           "#9DC012",
    'cDC1':             "#76A7CB",
    'cDC2':             "#16D2E3",
    'pDC':              "#69FFCB",
    'NK CD56 bright':  "#F3AC1F",
    'NK CD56 dim':     "#FBEF0D",
}

# -----------------------------------------------------------------------------
# OPTIONAL: SHAP dependency
# -----------------------------------------------------------------------------
# HAS_SHAP gates the SHAP importance/beeswarm step. Any import failure (not only
# ImportError) disables it.
try:
    import shap  # noqa: F401
    HAS_SHAP = True
except Exception:
    HAS_SHAP = False


# -----------------------------------------------------------------------------
# EXPORT SWITCHES
# -----------------------------------------------------------------------------
EXPORT_RELEASE = True    # set False to disable Release outputs
EXPORT_DEV     = False   # set True to enable Dev outputs


# -----------------------------------------------------------------------------
# CONFIG
# -----------------------------------------------------------------------------
name_target_class = "Broad"  # "Broad" | "Simplified" | "Detailed"
# These override the `kf` (KFold) and `num_cores` set in the notebook set-up cell.
kf          = MLTraining.CV
num_cores   = -1  # presumably forwarded as n_jobs (-1 = all cores) — confirm in MLTraining
metrics_log = []  # per-head evaluation dicts appended during training

# -----------------------------------------------------------------------------
# EMBEDDING CONFIG (for Class_Train_data.png)
# -----------------------------------------------------------------------------
# Choose where to read the 2D embedding from.
# Supported:
#   - "adata_obsm": read from adata_train.obsm[obsm_key]
#   - "adata_obs":  read from adata_train.obs[[obs_x, obs_y]]
#   - "train_df":   read from train_df[[df_x, df_y]] (e.g., Zhang_data_Train has UMAP columns)
EMBEDDING_SOURCE = "adata_obsm"   # "adata_obsm" | "adata_obs" | "train_df"

# If EMBEDDING_SOURCE == "adata_obsm"
EMBEDDING_OBSM_KEY = "X_umap"     # e.g. "X_umap", "X_pca"

# If EMBEDDING_SOURCE == "adata_obs"
EMBEDDING_OBS_X = "UMAP_1"
EMBEDDING_OBS_Y = "UMAP_2"

# If EMBEDDING_SOURCE == "train_df"
EMBEDDING_DF_X = "UMAP_1"
EMBEDDING_DF_Y = "UMAP_2"


# -----------------------------------------------------------------------------
# ROOTS
# -----------------------------------------------------------------------------
Zhang_root = Path(models_output)

# DEV tree: <Zhang>/Dev/<depth>/{Models,Reports,Figures}/<depth>/...
# NOTE(review): the DEV paths contain the depth name twice (e.g. Dev/Broad/Models/Broad),
# while RELEASE paths contain it once — confirm this nesting is intended.
dev_root     = Zhang_root / "Dev"
models_root  = dev_root / name_target_class / "Models"  / name_target_class
reports_root = dev_root / name_target_class / "Reports" / name_target_class
fig_root     = dev_root / name_target_class / "Figures" / name_target_class

heads_dir    = models_root / "heads"
metrics_dir  = reports_root / "metrics"
probs_dir    = reports_root / "probabilities"
fig_percls   = fig_root / "per_class"
dev_importances = reports_root / "Importances"

# RELEASE tree: <Zhang>/Release/<depth>/{Models,Reports/...,Figures/...}
release_root     = Zhang_root / "Release"
release_models   = release_root / name_target_class / "Models"
release_reports  = release_root / name_target_class / "Reports"
release_metrics  = release_reports / "Metrics"
release_probs    = release_reports / "Probabilities"
release_imps     = release_reports / "Importances"
release_figs     = release_root / name_target_class / "Figures"
release_single   = release_figs / "Single_classes"

# Create directories conditionally
# Only directories of enabled export targets are created.
if EXPORT_DEV:
    for p in (models_root, heads_dir, reports_root, metrics_dir, probs_dir, fig_root, fig_percls, dev_importances):
        p.mkdir(parents=True, exist_ok=True)
    print(f"[INFO] DEV Models:  {models_root}")
    print(f"[INFO] DEV Reports: {reports_root}")
    print(f"[INFO] DEV Figures: {fig_root}")

if EXPORT_RELEASE:
    for p in (release_models, release_reports, release_metrics, release_probs, release_imps, release_figs, release_single):
        p.mkdir(parents=True, exist_ok=True)
    print(f"[INFO] RELEASE Root:    {release_root}")
    print(f"[INFO] RELEASE Models:  {release_models}")
    print(f"[INFO] RELEASE Reports: {release_reports}")
    print(f"[INFO] RELEASE Figures: {release_figs}")


# =============================================================================
# SECTION 1: ATTACH CELL-TYPE LABELS
# =============================================================================
print("\n[STEP 1] Attaching cell-type labels from AnnData.obs...")

# e.g. "Consensus_annotation_broad_final" for name_target_class == "Broad".
consensus_field = f"Consensus_annotation_{name_target_class.lower()}_final"

# MLTraining.attach_celltype pairs each feature table with its AnnData split
# (presumably adding a "Celltype" column from obs[consensus_field] — confirm in MLTraining).
Zhang_data_Train = MLTraining.attach_celltype(Zhang_data_Train, Zhang_dataset_Train, consensus_field)
Zhang_data_Test  = MLTraining.attach_celltype(Zhang_data_Test,  Zhang_dataset_Test,  consensus_field)
Zhang_data_Cal   = MLTraining.attach_celltype(Zhang_data_Cal,   Zhang_dataset_Cal,   consensus_field)

print(f"  ✓ Attached '{consensus_field}' to Train/Test/Cal splits")


# =============================================================================
# SECTION 2: ALIGN DATA COLUMNS TO REFERENCE PANEL
# =============================================================================
print("\n[STEP 2] Aligning data columns to reference panel (exact names preserved)...")

# Reference panel names as strings, plus a normalised-key -> exact-panel-name map
# used to rename data columns.
panel = pd.Index(map(str, TotalSeqD_Heme_Oncology_CAT399906))
panel_keys = MLTraining.norm_feats(panel)
norm_to_panel = dict(zip(panel_keys, panel))
# Two panel names normalising to the same key would make the mapping ambiguous.
if len(norm_to_panel) != len(panel):
    raise ValueError("Panel contains names that collide after normalization. Adjust MLTraining.norm_feats rules.")

def rename_data_to_panel(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df* with feature columns renamed to their exact panel names.

    Every column except ``cell_barcode`` / ``Celltype`` is treated as a feature. A
    feature is matched by its ``MLTraining.norm_feats`` key looked up in the
    module-level ``norm_to_panel`` map; unmatched features keep their names. If
    several features map to the same panel name, only the first (in column
    order) is renamed and the others are dropped with a warning.

    Args:
        df: feature table, optionally containing ``cell_barcode`` / ``Celltype``.

    Returns:
        A new DataFrame; *df* itself is not modified.
    """
    out = df.copy()
    meta_cols = [col for col in ("cell_barcode", "Celltype") if col in out.columns]
    feature_cols = pd.Index([col for col in out.columns if col not in meta_cols])

    targets = [norm_to_panel.get(key) for key in MLTraining.norm_feats(feature_cols)]
    candidates = {src: dst for src, dst in zip(feature_cols, targets) if dst is not None}

    # First source column claiming a panel name wins; later claimants are dropped.
    claimed = set()
    renames = {}
    duplicates = []
    for src, dst in candidates.items():
        if dst in claimed:
            duplicates.append(src)
            continue
        claimed.add(dst)
        renames[src] = dst

    if duplicates:
        print(f"  [WARN] Dropping {len(duplicates)} duplicated-mapped columns (sample: {duplicates[:5]})")
        out = out.drop(columns=duplicates, errors="ignore")

    out = out.rename(columns=renames)
    print(f"  ✓ Matched {len(renames)}/{len(feature_cols)} data columns to panel")
    return out

def panel_intersection(df: pd.DataFrame) -> pd.DataFrame:
    """Restrict *df* to panel features (in panel order) followed by metadata columns.

    Uses the module-level ``panel`` index. Metadata columns are ``cell_barcode`` and
    ``Celltype`` when present.

    Raises:
        ValueError: if none of *df*'s feature columns appear in ``panel``.
    """
    meta_cols = [col for col in ("cell_barcode", "Celltype") if col in df.columns]
    feature_cols = pd.Index([col for col in df.columns if col not in meta_cols])

    # sort=False keeps the panel's own ordering.
    shared = panel.intersection(feature_cols, sort=False)
    if len(shared) == 0:
        raise ValueError("Panel/Data intersection is empty after renaming. Check mapping rules.")
    return df.reindex(columns=[*shared, *meta_cols])

# Rename each split to exact panel names, then keep only panel features (panel order).
Zhang_data_Train = panel_intersection(rename_data_to_panel(Zhang_data_Train))
Zhang_data_Test  = panel_intersection(rename_data_to_panel(Zhang_data_Test))
Zhang_data_Cal   = panel_intersection(rename_data_to_panel(Zhang_data_Cal))

print("  ✓ Data columns now aligned to panel (panel order preserved)")


# =============================================================================
# SECTION 3: PREPARE FEATURES & LABELS
# =============================================================================
print("\n[STEP 3] Extracting features and labels...")

Zhang_data_Cal_lbl = Zhang_data_Cal[["Celltype"]].copy()

# Feature-only matrices: drop the metadata columns present in each split.
drop_cols_train = [c for c in ["cell_barcode", "Celltype"] if c in Zhang_data_Train.columns]
drop_cols_test  = [c for c in ["cell_barcode", "Celltype"] if c in Zhang_data_Test.columns]
drop_cols_cal   = [c for c in ["cell_barcode", "Celltype"] if c in Zhang_data_Cal.columns]

Zhang_data_Train_Sub = Zhang_data_Train.drop(columns=drop_cols_train, errors="ignore")
Zhang_data_Test_Sub  = Zhang_data_Test.drop(columns=drop_cols_test,  errors="ignore")
Zhang_data_Cal_Sub   = Zhang_data_Cal.drop(columns=drop_cols_cal,    errors="ignore")

# All splits must share identical feature columns in identical order.
cols_train = list(Zhang_data_Train_Sub.columns)
if list(Zhang_data_Test_Sub.columns) != cols_train or list(Zhang_data_Cal_Sub.columns) != cols_train:
    raise ValueError("Train/Cal/Test feature columns differ after panel intersection!")

MLTraining.check_finite(Zhang_data_Train_Sub, "TRAIN")
MLTraining.check_finite(Zhang_data_Test_Sub,  "TEST")
MLTraining.check_finite(Zhang_data_Cal_Sub,   "CAL")

print(f"  ✓ Using {len(cols_train)} panel-intersected features (exact panel names)")
print(f"    Sample: {cols_train[:5]}...")

# Class order = sorted unique TRAIN labels; column k of every probability matrix is class_names[k].
class_names  = sorted(pd.Series(Zhang_data_Train["Celltype"]).dropna().unique())
K            = len(class_names)
class_to_idx = {c: i for i, c in enumerate(class_names)}
print(f"  ✓ Found {K} classes")

# Integer targets for CAL/TEST; any label unseen in TRAIN (including NaN) is an error.
s_cal = Zhang_data_Cal_lbl["Celltype"].map(class_to_idx)
if s_cal.isna().any():
    missing = Zhang_data_Cal_lbl.loc[s_cal.isna(), "Celltype"].unique()
    raise ValueError(f"Unknown labels in CAL split: {missing}")
y_cal_multiclass = s_cal.to_numpy(dtype=np.int64)

s_te = Zhang_data_Test["Celltype"].map(class_to_idx)
if s_te.isna().any():
    missing = Zhang_data_Test.loc[s_te.isna(), "Celltype"].unique()
    raise ValueError(f"Unknown labels in TEST split: {missing}")
y_test_multiclass = s_te.to_numpy(dtype=np.int64)

X_cal_all_df = Zhang_data_Cal_Sub.copy()
X_te_all_df  = Zhang_data_Test_Sub.copy()
test_index   = Zhang_data_Test_Sub.index

# (n_cells, K) OvR probability matrices, filled column by column in SECTION 4.
P_cal_raw   = np.zeros((X_cal_all_df.shape[0], K), dtype=float)
P_cal_platt = np.zeros((X_cal_all_df.shape[0], K), dtype=float)

P_te_raw    = np.zeros((X_te_all_df.shape[0],  K), dtype=float)
P_te_platt  = np.zeros((X_te_all_df.shape[0],  K), dtype=float)

# class name -> head bundle (scaler + models), saved later in the deployable package.
heads_mem = {}

# Importances collectors
xgb_shap_rows = []       # mean_abs + corr (later filtered top10/class)
lr_contrib_rows = []     # LR base learner contributions (from stacker_raw)
platt_metrics_rows = []  # per-class logloss/brier pre vs post platt


# =============================================================================
# SECTION 4: TRAIN OvR BINARY HEADS (+ Platt on CAL)
# =============================================================================
# For each class: train a binary (class vs rest) stacked head on the class-specific
# TRAIN barcodes, Platt-calibrate it on CAL, fill column k of the RAW/PLATT OvR
# probability matrices for CAL and TEST, and collect plots, importances and metrics.
print(f"\n[STEP 4] Training {K} binary OvR classifiers...\n")

TOP_N = 10
base_order = ["NB", "XGB", "KNN", "MLP"]

# Palette for the per-class TRAIN UMAP, chosen from the configured depth so that
# "Simplified"/"Detailed" runs get a palette keyed by their own class labels.
# Unknown depths fall back to PALETTE_BROAD, which was previously always used.
umap_palette = {
    "Broad": PALETTE_BROAD,
    "Simplified": PALETTE_SIMPLIFIED,
    "Detailed": PALETTE_DETAILED,
}.get(name_target_class, PALETTE_BROAD)

for celltype in class_names:
    k = class_to_idx[celltype]
    cls_safe = MLTraining.safe_name(celltype)
    print(f"▸ Processing {cls_safe} (class {k+1}/{K})")

    # 4.1 Load TRAIN barcodes for this class
    train_barcodes_df = pd.read_csv(
        f"{train_barcodes_path}/Zhang/Consensus_annotation_{name_target_class.lower()}_final/Barcodes_training_class_{cls_safe}.csv",
        index_col=0
    )
    train_positive_barcodes = train_barcodes_df["Positive"].dropna().values
    train_negative_barcodes = train_barcodes_df["Negative"].dropna().values
    all_train_barcodes = np.concatenate([train_positive_barcodes, train_negative_barcodes])

    # Keep the listed barcodes that exist in the TRAIN matrix; label 1 = positive list.
    train_mask = Zhang_data_Train_Sub.index.isin(all_train_barcodes)
    X_tr_df = Zhang_data_Train_Sub.loc[train_mask]
    found_train_barcodes = X_tr_df.index.values
    y_tr = np.isin(found_train_barcodes, train_positive_barcodes).astype(int)

    if X_tr_df.empty or np.unique(y_tr).size < 2:
        print(f"  [SKIP] Empty or single-class train (pos={y_tr.sum()}, neg={len(y_tr)-y_tr.sum()})\n")
        continue

    # 4.1b TRAIN UMAP (pos vs rest) + legend
    try:
        MLTraining.save_class_train_umap_pngs(
            celltype=str(celltype),
            cls_safe=cls_safe,
            barcodes=found_train_barcodes,
            y_bin=y_tr,
            custom_palette=umap_palette,
            out_dir_dev=fig_percls if EXPORT_DEV else None,
            out_dir_rel=release_single if EXPORT_RELEASE else None,
            adata_train=Zhang_dataset_Train,
            train_df=Zhang_data_Train,
            embedding_source=EMBEDDING_SOURCE,
            obsm_key=EMBEDDING_OBSM_KEY,
            obs_x=EMBEDDING_OBS_X,
            obs_y=EMBEDDING_OBS_Y,
            df_x=EMBEDDING_DF_X,
            df_y=EMBEDDING_DF_Y,
            neg_color="#A3A3A3",
            outline=(5, 0.05),
            debug=(str(celltype) == "Mature"),
        )

    except Exception as e:
        warnings.warn(f"UMAP train plot failed for '{celltype}': {e}")

    # 4.2 Load TEST barcodes for class-specific metrics (optional)
    test_barcodes_df = pd.read_csv(
        f"{test_barcodes_path}/Zhang/Consensus_annotation_{name_target_class.lower()}_final/Barcodes_testing_class_{cls_safe}.csv",
        index_col=0
    )
    test_positive_barcodes = test_barcodes_df["Positive"].dropna().values
    test_negative_barcodes = test_barcodes_df["Negative"].dropna().values
    all_test_barcodes = np.concatenate([test_positive_barcodes, test_negative_barcodes])

    test_mask = Zhang_data_Test_Sub.index.isin(all_test_barcodes)
    X_te_df = Zhang_data_Test_Sub.loc[test_mask]
    found_test_barcodes = X_te_df.index.values
    y_te = np.isin(found_test_barcodes, test_positive_barcodes).astype(int)

    # Full TEST for head probabilities / calibration plot eval
    X_te_all_local = X_te_all_df
    y_te_all = (Zhang_data_Test["Celltype"].values == celltype).astype(int)

    # CAL split for Platt fitting
    X_cal_df  = X_cal_all_df
    y_cal_bin = (Zhang_data_Cal_lbl["Celltype"].values == celltype).astype(int)

    # 4.3 Fit scaler on TRAIN; transform all splits
    scaler = StandardScaler(with_mean=True, with_std=True).fit(X_tr_df.values)

    def _sc(df: pd.DataFrame) -> pd.DataFrame:
        # Apply this class's TRAIN-fitted scaler, keeping index and feature names.
        return pd.DataFrame(
            scaler.transform(df.values),
            index=df.index,
            columns=cols_train
        )

    X_tr_sc_df      = _sc(X_tr_df)
    # The class-specific TEST subset can be empty if none of its barcodes are in
    # the TEST matrix. StandardScaler.transform rejects 0-row input, so skip
    # scaling here and let 4.9 skip the per-head metrics instead of crashing the run.
    X_te_sc_df      = _sc(X_te_df) if not X_te_df.empty else None
    X_te_all_sc_df  = _sc(X_te_all_local)
    X_cal_sc_df     = _sc(X_cal_df)

    # 4.4 Train base learners
    NB_model  = MLTraining.train_NB (X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)
    XGB_model = MLTraining.train_XGB(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)
    KNN_model = MLTraining.train_KNN(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)
    MLP_model = MLTraining.train_MLP(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)

    # 4.5 Stacking RAW head
    stacker_raw = StackingClassifier(
        estimators=[("NB", NB_model), ("XGB", XGB_model), ("KNN", KNN_model), ("MLP", MLP_model)],
        final_estimator=LogisticRegression(max_iter=2000, class_weight="balanced", random_state=42),
        stack_method="predict_proba",
        cv=kf,
        n_jobs=-1,
    ).fit(X_tr_sc_df, y_tr)

    # 4.6 Platt calibration (fit on CAL only)
    # Sigmoid calibration needs both classes present in CAL.
    pos_cal   = int(y_cal_bin.sum())
    n_cal_bin = int(len(y_cal_bin))
    has_both  = (0 < pos_cal < n_cal_bin)

    stacker_platt = None
    if has_both:
        stacker_platt = MLTraining.calibrate_prefit(stacker_raw, X_cal_sc_df, y_cal_bin, method="sigmoid")
    else:
        print("    [WARN] Skipped Platt calibration (single-class CAL)")

    # 4.7 Platt evaluation curve on TEST (Ideal -> RAW -> Platt) + metrics row
    try:
        p_test_raw   = stacker_raw.predict_proba(X_te_all_sc_df)[:, 1]
        p_test_platt = stacker_platt.predict_proba(X_te_all_sc_df)[:, 1] if stacker_platt is not None else None

        dev_platt = (fig_percls / f"{cls_safe}_Platt_calibration_evaluation_on_test.png") if EXPORT_DEV else None
        rel_platt = (release_single / f"{cls_safe}_Platt_calibration_evaluation_on_test.png") if EXPORT_RELEASE else None

        ll_raw, br_raw, ll_pl, br_pl, pl_avail = MLTraining.plot_platt_calibration_on_test(
            y_true_bin=y_te_all.astype(int),
            p_raw=p_test_raw,
            p_platt=p_test_platt,
            title=f"{name_target_class} – {celltype}: Platt calibration evaluation on TEST",
            out_png_dev=dev_platt,
            out_png_rel=rel_platt,
            n_bins=15,
        )

        platt_metrics_rows.append({
            "depth": name_target_class,
            "class_name": str(celltype),
            "n_test_samples": int(len(y_te_all)),
            "n_test_positive": int(y_te_all.sum()),
            "logloss_raw": ll_raw,
            "brier_raw": br_raw,
            "logloss_platt": ll_pl,
            "brier_platt": br_pl,
            "platt_available": bool(pl_avail),
        })

    except Exception as e:
        warnings.warn(f"Platt calibration plot failed for class '{celltype}': {e}")

    # 4.8 Save per-class head bundle + keep in-memory for package
    head_bundle = {
        "atlas": "Zhang",
        "depth": name_target_class,
        "label": str(celltype),
        "panel_name": "TotalSeqD_Heme_Oncology_CAT399906",
        "columns": cols_train,
        "scaler": scaler,
        "model_raw": stacker_raw,
        "model_platt": stacker_platt,
    }
    heads_mem[str(celltype)] = head_bundle

    if EXPORT_DEV:
        joblib.dump(head_bundle, heads_dir / f"{cls_safe}.joblib")

    # 4.9 Optional per-head metrics logging (class-specific TEST subset).
    # Best-effort: failures are reported as warnings instead of being silently dropped.
    if X_te_sc_df is None:
        warnings.warn(f"No TEST barcodes found for class '{celltype}'; skipping per-head metrics")
    else:
        try:
            model_for_eval = stacker_platt if stacker_platt is not None else stacker_raw
            m = MLTraining.evaluate_classifier(model_for_eval, X_te_sc_df, y_te, plot_cm=False)
            m.update(celltype=str(celltype), used_platt=bool(stacker_platt is not None))
            metrics_log.append(m)
        except Exception as e:
            warnings.warn(f"Per-head evaluation failed for class '{celltype}': {e}")

    # 4.10 OvR probability matrices (RAW + PLATT) for multiclass downstream
    P_cal_raw[:, k] = stacker_raw.predict_proba(X_cal_sc_df)[:, 1]
    P_te_raw[:,  k] = stacker_raw.predict_proba(X_te_all_sc_df)[:, 1]

    # Without a Platt model the PLATT column reuses the RAW probabilities.
    if stacker_platt is not None:
        P_cal_platt[:, k] = stacker_platt.predict_proba(X_cal_sc_df)[:, 1]
        P_te_platt[:,  k] = stacker_platt.predict_proba(X_te_all_sc_df)[:, 1]
    else:
        P_cal_platt[:, k] = P_cal_raw[:, k]
        P_te_platt[:,  k] = P_te_raw[:,  k]

    # 4.11 SHAP: mean_abs + corr on TEST; beeswarm TRAIN only
    if HAS_SHAP:
        try:
            shap_sum_test = MLTraining.xgb_shap_mean_abs_and_corr(XGB_model, X_te_all_sc_df, class_index=1)
            shap_sum_test["depth"] = name_target_class
            shap_sum_test["class_name"] = str(celltype)
            shap_sum_test["dataset"] = "TEST"
            xgb_shap_rows.extend(shap_sum_test.to_dict(orient="records"))

            # Beeswarm on TRAIN only. SHAP values are computed once and reused for
            # every enabled export target (previously recomputed per target).
            beeswarm_paths = []
            if EXPORT_DEV:
                beeswarm_paths.append(fig_percls / f"{cls_safe}_SHAP_beeswarm_TRAIN.png")
            if EXPORT_RELEASE:
                beeswarm_paths.append(release_single / f"{cls_safe}_SHAP_beeswarm_TRAIN.png")

            if beeswarm_paths:
                sv, _ = MLTraining.xgb_shap_values(XGB_model, X_tr_sc_df, class_index=1)
                for outp in beeswarm_paths:
                    MLTraining.plot_xgb_shap_beeswarm(
                        sv, X_tr_sc_df,
                        title=f"{name_target_class} – {celltype}: SHAP importance",
                        max_display=TOP_N,
                        out_path=outp,
                        figsize=(6, 7),
                    )

        except Exception as e:
            warnings.warn(f"SHAP failed for class '{celltype}': {e}")

    # 4.12 LR meta-learner contributions.
    # `_lr_baselearner_contributions` is a notebook-level helper defined in an earlier
    # cell (not in MLTraining); switch to the MLTraining version if it is moved there.
    try:
        contrib = _lr_baselearner_contributions(stacker_raw, X_te_all_sc_df, base_order=base_order)
        row = {
            "depth": name_target_class,
            "class_name": str(celltype),
            "dataset": "TEST",
            "n_meta_features": contrib["n_meta_features"],
            "per_estimator_meta_cols": contrib["per_estimator_meta_cols"],
        }
        # Missing base learners are reported with zero contribution.
        for b in base_order:
            row[f"{b}_mean_abs_contribution"] = contrib["per_base"].get(b, {}).get("mean_abs_contribution", 0.0)
            row[f"{b}_coef_l1"]               = contrib["per_base"].get(b, {}).get("coef_l1", 0.0)
            row[f"{b}_n_meta_cols"]           = contrib["per_base"].get(b, {}).get("n_cols", 0)
        lr_contrib_rows.append(row)
    except Exception as e:
        warnings.warn(f"LR contribution extraction failed for class '{celltype}': {e}")

    print("")


# =============================================================================
# EXPORT: Per-class LogLoss & Brier (pre vs post Platt) on TEST
# =============================================================================
# Hands the rows accumulated in `platt_metrics_rows` during SECTION 4 to
# MLTraining.export_platt_metrics_csv. A disabled export target is passed as None
# (presumably the helper skips it — confirm in MLTraining).
print("\n[EXPORT] Per-class calibration metrics (RAW vs Platt on TEST)...")

# Return value is intentionally discarded.
_ = MLTraining.export_platt_metrics_csv(
    platt_metrics_rows,
    out_dev=metrics_dir if EXPORT_DEV else None,
    out_rel=release_metrics if EXPORT_RELEASE else None,
    filename="Single_classes_metrics_pre_and_post_platt_calibration.csv",
)


# =============================================================================
# SECTION 5: MULTICLASS TEMPERATURE SCALING (fit on CAL using PLATT matrix)
# =============================================================================
print("\n[STEP 5] Multiclass Temperature Scaling on CAL (using Platt OvR probabilities)...")

def _check_probs(P: np.ndarray, name: str):
    if np.isnan(P).any() or np.isinf(P).any():
        raise ValueError(f"{name} contains NaN/Inf")
    if (P < 0).any() or (P > 1).any():
        raise ValueError(f"{name} contains values outside [0,1]")

# Both OvR (Platt) probability matrices must be valid before temperature scaling.
_check_probs(P_cal_platt, "P_cal_platt")
_check_probs(P_te_platt,  "P_te_platt")

# Fit one multiclass temperature on CAL (Platt OvR probabilities as input),
# then apply it to the TEST matrix.
ts_cal = TemperatureScaling()
ts_cal.fit(P_cal_platt, y_cal_multiclass)
P_te_cal = ts_cal.transform(P_te_platt)

# Coerce the scaler output to an (n_samples, K) matrix.
P_te_cal = np.asarray(P_te_cal)
if P_te_cal.ndim == 1:
    P_te_cal = P_te_cal.reshape(-1, 1)
if P_te_cal.shape[1] == 1 and K == 2:
    # Binary case: the single column is treated as P(class 1); rebuild both columns.
    P_te_cal = np.hstack([1.0 - P_te_cal, P_te_cal])
elif P_te_cal.shape[1] != K:
    # Unexpected width. Remember the scaler's shape BEFORE overwriting P_te_cal,
    # so the warning reports what TemperatureScaling actually returned.
    ts_out_shape = P_te_cal.shape
    # Fallback: row-normalise the OvR probabilities (all-zero rows stay zero).
    row_sums = P_te_platt.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0.0] = 1.0
    P_te_cal = P_te_platt / row_sums
    print(f"  [WARN] TemperatureScaling returned shape {ts_out_shape}; fell back to sum-normalized OvR probs")

if EXPORT_DEV:
    joblib.dump(ts_cal, models_root / "temp_scaler.joblib")
    pd.Series(class_names, name="class_name").to_csv(models_root / "class_names.csv", index=False)


# =============================================================================
# SECTION 5b: SAVE DEPLOYABLE PACKAGE(S)
# =============================================================================
print("\n[STEP 5b] Saving deployable package(s)...")

# Inference bundle: the in-memory OvR heads, the class order, and the
# multiclass temperature scaler fitted in Section 5.
package = dict(
    atlas="Zhang",
    depth=name_target_class,
    panel_name="TotalSeqD_Heme_Oncology_CAT399906",
    class_names=class_names,
    heads=heads_mem,
    temp_scaler=ts_cal,
)

# The same object is written once per enabled destination.
if EXPORT_DEV:
    joblib.dump(package, models_root / "package.joblib")
if EXPORT_RELEASE:
    joblib.dump(package, release_models / "Multiclass_models.joblib")


# =============================================================================
# SECTION 5c: EXPORT IMPORTANCES (Top10 per class)
# =============================================================================
print("\n[STEP 5c] Exporting importances (Top 10 per class; SHAP mean_abs + corr + LR)...")

# --- XGB SHAP: keep the TOP_N features per (depth, class_name) -------------
shap_df = None
if len(xgb_shap_rows) == 0:
    print("  [INFO] No SHAP rows collected (or SHAP not installed).")
else:
    _shap_group = ["depth", "class_name"]

    # Sort each group by descending mean |SHAP|, then take the first TOP_N rows.
    shap_df = pd.DataFrame(xgb_shap_rows).sort_values(
        _shap_group + ["mean_abs_shap"], ascending=[True, True, False]
    )
    shap_df = shap_df.groupby(_shap_group, as_index=False).head(TOP_N)

    # Rank 1 = largest mean |SHAP| in the group; ties resolved by row order.
    _ranks = shap_df.groupby(_shap_group)["mean_abs_shap"].rank(ascending=False, method="first")
    shap_df["rank_within_class"] = _ranks.astype(int)

    keep_cols = [
        "depth", "class_name", "dataset",
        "feature", "mean_abs_shap", "corr_feature_value_vs_shap",
        "rank_within_class",
    ]
    shap_df = shap_df[keep_cols]

    if EXPORT_DEV:
        shap_df.to_csv(dev_importances / "SHAP_XGB_Feature_importances.csv", index=False)
    if EXPORT_RELEASE:
        shap_df.to_csv(release_imps / "SHAP_XGB_Feature_importances.csv", index=False)

# --- LR meta-learner: per-base-learner contribution rows -------------------
lr_df = None
if len(lr_contrib_rows) == 0:
    print("  [INFO] No LR contribution rows collected.")
else:
    lr_df = pd.DataFrame(lr_contrib_rows)
    if EXPORT_DEV:
        lr_df.to_csv(dev_importances / "LR_MetaLearner_BaseLearner_contributions.csv", index=False)
    if EXPORT_RELEASE:
        lr_df.to_csv(release_imps / "LR_MetaLearner_BaseLearner_contributions.csv", index=False)


# =============================================================================
# SECTION 6: SAVE PROBABILITIES
# =============================================================================
print("\n[STEP 6] Saving probability outputs...")

# Ground-truth labels aligned to the rows of the probability matrices.
# P_te_* rows are indexed by ``test_index``. Using the full, unfiltered
# ``Zhang_data_Test["Celltype"]`` column fails with a length mismatch (or
# misaligns) as soon as any TEST rows were filtered out upstream.
true_label_te = Zhang_data_Test.loc[test_index, "Celltype"].values

if EXPORT_DEV:
    # DEV: RAW, Platt and temperature-scaled (CAL) probabilities side by side.
    probs_raw_df   = pd.DataFrame(P_te_raw,   index=test_index, columns=[f"raw_{c}"   for c in class_names])
    probs_platt_df = pd.DataFrame(P_te_platt, index=test_index, columns=[f"platt_{c}" for c in class_names])
    probs_cal_df   = pd.DataFrame(P_te_cal,   index=test_index, columns=[f"cal_{c}"   for c in class_names])

    probs_dev = pd.concat([probs_raw_df, probs_platt_df, probs_cal_df], axis=1)
    probs_dev["true_label"] = true_label_te
    probs_dev["pred_raw"]   = P_te_raw.argmax(axis=1)
    probs_dev["pred_cal"]   = P_te_cal.argmax(axis=1)
    probs_dev["pred_raw_name"] = [class_names[i] for i in probs_dev["pred_raw"].values]
    probs_dev["pred_cal_name"] = [class_names[i] for i in probs_dev["pred_cal"].values]

    probs_dev_path = probs_dir / "probabilities_before_after_TEST.csv"
    probs_dev.to_csv(probs_dev_path, index=True)

if EXPORT_RELEASE:
    # RELEASE: only the calibrated probabilities plus argmax and max-probability columns.
    probs_cal_df = pd.DataFrame(P_te_cal, index=test_index, columns=[f"cal_{c}" for c in class_names])
    probs_release = probs_cal_df.copy()
    probs_release["true_label"]    = true_label_te
    probs_release["pred_cal"]      = P_te_cal.argmax(axis=1)
    probs_release["pred_cal_name"] = [class_names[i] for i in probs_release["pred_cal"].values]
    probs_release["max_cal_prob"]  = probs_cal_df.max(axis=1).values

    release_probs_path = release_probs / "Multiclass_models_probabilities_on_test.csv"
    probs_release.to_csv(release_probs_path, index=True)


# =============================================================================
# SECTION 7: MULTICLASS EVALUATION (TEST) — using CAL probabilities
# =============================================================================
print("\n[STEP 7] Multiclass evaluation (TEST; using CAL probs)...\n")

y_pred_cal = P_te_cal.argmax(axis=1)

# Pin the label set to every class index 0..K-1 so that ``target_names`` always
# matches. Without ``labels``, sklearn infers labels from y_true ∪ y_pred and
# raises ValueError whenever a class is absent from both.
mc_labels = list(range(K))

report_txt = classification_report(
    y_test_multiclass, y_pred_cal, labels=mc_labels, target_names=class_names, digits=3
)
print("Multiclass Classification Report (TEST):")
print(report_txt)

cm_mc = confusion_matrix(y_test_multiclass, y_pred_cal, labels=mc_labels)
print("\nConfusion Matrix (rows=true, cols=pred):")
print(cm_mc)

report = classification_report(
    y_test_multiclass, y_pred_cal, labels=mc_labels, target_names=class_names, output_dict=True
)
report_df = pd.DataFrame(report).T

cm_mc_df = pd.DataFrame(
    cm_mc,
    index=pd.Index(class_names, name="true"),
    columns=pd.Index(class_names, name="pred"),
)

if EXPORT_DEV:
    report_df.to_csv(metrics_dir / "multiclass_classification_report_TEST.csv")
    cm_mc_df.to_csv(metrics_dir / "multiclass_confusion_matrix_TEST.csv")

if EXPORT_RELEASE:
    report_df.to_csv(release_metrics / "Multiclass_models_metrics_on_test.csv")
    cm_mc_df.to_csv(release_metrics / "Multiclass_models_confusion_matrix_on_test.csv")


# =============================================================================
# SECTION 8: FIGURES (MULTICLASS CM + PER-CLASS CONF & ROC)
# =============================================================================
print("\n[STEP 8] Saving plots...")

def _save_multiclass_cm_png(out_path: Path):
    """Render the multiclass TEST confusion matrix to a 300-dpi PNG at *out_path*.

    Reads the module-level ``cm_mc``, ``class_names`` and ``name_target_class``
    computed in Section 7.
    """
    # Draw on an explicitly created Axes. ConfusionMatrixDisplay.plot() opens its
    # own figure when ``ax`` is omitted, which ignored figsize and left the drawn
    # figure open (only the empty sized figure was closed).
    fig, ax = plt.subplots(figsize=(7, 6))
    try:
        disp = ConfusionMatrixDisplay(confusion_matrix=cm_mc, display_labels=class_names)
        disp.plot(ax=ax, values_format="d", cmap="Blues", colorbar=False)
        ax.set_title(f"{name_target_class} – Multiclass Confusion Matrix (on TEST)")
        fig.tight_layout()
        fig.savefig(out_path, dpi=300)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)

# Multiclass confusion-matrix PNG, once per enabled destination.
if EXPORT_DEV:
    _save_multiclass_cm_png(fig_root / "multiclass_confusion_matrix_TEST.png")
if EXPORT_RELEASE:
    _save_multiclass_cm_png(release_figs / "Multiclass_models_confusion_matrix_on_test.png")

# One metrics dict per (class, model) is appended by the per-class loop below.
per_class_rows = []

# Hard argmax predictions from the RAW and CAL (Platt + temperature) probabilities.
y_pred_raw, y_pred_cal = P_te_raw.argmax(axis=1), P_te_cal.argmax(axis=1)

def _metrics_from_cm(cm2x2):
    tn, fp, fn, tp = cm2x2.ravel()
    support = tp + fn
    prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    rec  = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1   = (2 * prec * rec) / (prec + rec) if (prec + rec) > 0 else 0.0
    return dict(TP=int(tp), FP=int(fp), TN=int(tn), FN=int(fn),
                support=int(support), precision=prec, recall=rec, f1=f1)

def _save_cm_fig(cm2x2, cls_label, title, out_dev: Path | None, out_rel: Path | None):
    """Render a 2x2 one-vs-rest confusion matrix and save it to each non-None path.

    Parameters
    ----------
    cm2x2 : array-like
        Binary confusion matrix (index 0 displayed as "Other", index 1 as ``cls_label``).
    cls_label : str
        Display name of the positive class.
    title : str
        Figure title.
    out_dev, out_rel : Path or None
        DEV / RELEASE destinations; ``None`` skips that copy.
    """
    # Draw on an explicitly created Axes. ConfusionMatrixDisplay.plot() opens its
    # own figure when ``ax`` is omitted, which ignored figsize and leaked one
    # open figure per call (only the empty sized figure was closed).
    fig, ax = plt.subplots(figsize=(5.5, 5.0))
    try:
        ConfusionMatrixDisplay(confusion_matrix=cm2x2, display_labels=["Other", cls_label]).plot(
            ax=ax, values_format="d", cmap="Blues", colorbar=False
        )
        ax.set_title(title)
        fig.tight_layout()
        for out_path in (out_dev, out_rel):
            if out_path is not None:
                fig.savefig(out_path, dpi=300)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)

def _save_roc(y_true, y_score, title, out_dev: Path | None, out_rel: Path | None):
    """Plot a binary ROC curve, save it to each non-None path, and return the AUC.

    Parameters
    ----------
    y_true : array-like
        Binary ground truth (1 = positive class).
    y_score : array-like
        Continuous score for the positive class.
    title : str
        Title prefix; the AUC (3 decimals) is appended.
    out_dev, out_rel : Path or None
        DEV / RELEASE destinations; ``None`` skips that copy.
    """
    fpr, tpr, _thresholds = roc_curve(y_true, y_score)
    roc_auc = auc(fpr, tpr)

    fig = plt.figure(figsize=(6.0, 5.5))
    # Chance diagonal first, then the model's curve on top.
    plt.plot([0, 1], [0, 1], linestyle="--", linewidth=1, color="gray")
    plt.plot(fpr, tpr)
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"{title} AUC={roc_auc:.3f}")
    plt.tight_layout()
    for destination in (out_dev, out_rel):
        if destination is not None:
            plt.savefig(destination, dpi=300)
    plt.close(fig)
    return roc_auc

# Per-class one-vs-rest evaluation derived from the multiclass argmax predictions.
# For each class k: binarise truth/predictions, save RAW and CAL confusion
# matrices and ROC curves, and append one metrics row per model to per_class_rows.
for k, cls in enumerate(class_names):
    cls_safe = MLTraining.safe_name(cls)
    # 1 where the TEST row's true class index is k, else 0.
    y_true_bin = (y_test_multiclass == k).astype(int)

    # Continuous scores for the ROC curves: column k of each probability matrix.
    score_raw = P_te_raw[:, k]
    score_cal = P_te_cal[:, k]

    # Hard binary predictions: "argmax prediction is class k".
    y_pred_raw_bin = (y_pred_raw == k).astype(int)
    y_pred_cal_bin = (y_pred_cal == k).astype(int)

    # labels=[0, 1] keeps both matrices 2x2 even when one label is absent.
    cm_raw = confusion_matrix(y_true_bin, y_pred_raw_bin, labels=[0, 1])
    cm_cal = confusion_matrix(y_true_bin, y_pred_cal_bin, labels=[0, 1])

    if EXPORT_DEV:
        idx = pd.Index(["True=Other", f"True={cls}"], name="true")
        cols = pd.Index(["Pred=Other", f"Pred={cls}"], name="pred")
        pd.DataFrame(cm_raw, index=idx, columns=cols).to_csv(metrics_dir / f"{cls_safe}_binary_confmat_TEST_ARGMAX_RAW.csv")
        pd.DataFrame(cm_cal, index=idx, columns=cols).to_csv(metrics_dir / f"{cls_safe}_binary_confmat_TEST_ARGMAX_CAL.csv")

    # Each figure is written to the DEV and/or RELEASE folder; None skips that copy.
    dev_out = (fig_percls / f"{cls_safe}_RAW_confusion_matrix_on_test.png") if EXPORT_DEV else None
    rel_out = (release_single / f"{cls_safe}_RAW_confusion_matrix_on_test.png") if EXPORT_RELEASE else None
    _save_cm_fig(cm_raw, cls, f"{name_target_class} – {cls}: Confusion Matrix (RAW; pre-Platt & Temp)", dev_out, rel_out)

    dev_out = (fig_percls / f"{cls_safe}_CAL_confusion_matrix_on_test.png") if EXPORT_DEV else None
    rel_out = (release_single / f"{cls_safe}_CAL_confusion_matrix_on_test.png") if EXPORT_RELEASE else None
    _save_cm_fig(cm_cal, cls, f"{name_target_class} – {cls}: Confusion Matrix (CAL; Platt & Temp)", dev_out, rel_out)

    dev_out = (fig_percls / f"{cls_safe}_RAW_ROC_on_test.png") if EXPORT_DEV else None
    rel_out = (release_single / f"{cls_safe}_RAW_ROC_on_test.png") if EXPORT_RELEASE else None
    auc_raw = _save_roc(
        y_true_bin,
        score_raw,
        f"{name_target_class} – {cls}: ROC (RAW; pre-Platt & Temp)",
        dev_out,
        rel_out,
    )

    dev_out = (fig_percls / f"{cls_safe}_CAL_ROC_on_test.png") if EXPORT_DEV else None
    rel_out = (release_single / f"{cls_safe}_CAL_ROC_on_test.png") if EXPORT_RELEASE else None
    auc_cal = _save_roc(
        y_true_bin,
        score_cal,
        f"{name_target_class} – {cls}: ROC (CAL; Platt & Temp)",
        dev_out,
        rel_out,
    )

    # Count-based metrics from the 2x2 matrix, tagged with model, class and ROC AUC.
    m_raw = _metrics_from_cm(cm_raw)
    m_raw.update(model="RAW", class_name=cls, auc=auc_raw)
    per_class_rows.append(m_raw)

    m_cal = _metrics_from_cm(cm_cal)
    m_cal.update(model="CAL", class_name=cls, auc=auc_cal)
    per_class_rows.append(m_cal)

if EXPORT_DEV:
    print(f"  ✓ Saved per-class plots (DEV) → {fig_percls}")
if EXPORT_RELEASE:
    print(f"  ✓ Saved per-class plots (RELEASE) → {release_single}")


# =============================================================================
# SECTION 9: SAVE METRICS TABLES
# =============================================================================
print("\n[STEP 9] Saving metrics tables...")

# One row per (class, model) from the Section 8 loop, with a fixed column order.
per_class_df = pd.DataFrame(per_class_rows)[
    ["class_name", "model", "TP", "FP", "TN", "FN", "support", "precision", "recall", "f1", "auc"]
].sort_values(["class_name", "model"])

if EXPORT_DEV:
    dev_metrics_path = metrics_dir / "per_class_argmax_metrics_TEST_included.csv"
    per_class_df.to_csv(dev_metrics_path, index=False)
    print(f"  ✓ Saved DEV per-class metrics → {dev_metrics_path}")

if EXPORT_RELEASE:
    out_single = release_metrics / "Single_classes_metrics_and_confusion_matrix_on_test.csv"
    per_class_df.to_csv(out_single, index=False)
    print(f"  ✓ Saved RELEASE per-class metrics → {out_single}")

# DEV only: hand the accumulated metrics_log records to MLTraining.append_metrics_csv
# (append-to-existing semantics assumed from the helper name — confirm in MLTraining).
if EXPORT_DEV:
    metrics_df = pd.DataFrame.from_records(metrics_log)
    MLTraining.append_metrics_csv(metrics_df, csv_path=dev_root / "stacker_metrics.csv")
    print(f"  ✓ Appended DEV binary-head metrics → {dev_root / 'stacker_metrics.csv'}")

print("\n✅ BROAD PIPELINE COMPLETE. Exports saved according to EXPORT_DEV / EXPORT_RELEASE.\n")


#### Simplified annotation

In [None]:
# -*- coding: utf-8 -*-
# =============================================================================
# MODEL TRAINING PIPELINE (LEAN MAIN SCRIPT)
#   - RAW vs PLATT vs TEMP-SCALED
#   - DEV/RELEASE exports
#   - Importances: XGB SHAP mean_abs + corr (Top10) + LR meta-learner contributions
#   - Platt calibration plots (Ideal -> RAW -> Platt on top) with TEST LogLoss/Brier in legend
#   - Per-class pre/post Platt metrics exported to CSV
#   - Per-class TRAIN UMAP (pos vs rest) + legend PNG
#
# PATCHES ADDED (to address “plots missing / skipped” symptoms):
#   (A) Optional DEBUG_DIAGNOSTICS: prints output paths + CAL class balance + confirms file writes.
#   (B) Hard traceback on failures (instead of silent warnings) to surface root cause.
#   (C) SHAP beeswarm robustification: optional subsample of TRAIN to avoid memory/time failures.
#   (D) Optional SAFE_SINGLE_THREAD: mitigates fork/thread/numba/TBB instability during SHAP/plotting.
#   (E) Explicit existence checks after savefig (so “saved but not where expected” is obvious).
# =============================================================================

# =============================================================================
# SECTION 0: IMPORTS + CONFIG
# =============================================================================

from pathlib import Path
import joblib
import warnings
import traceback
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import StackingClassifier
from sklearn.metrics import (
    classification_report, confusion_matrix, ConfusionMatrixDisplay,
    roc_curve, auc,
)

import MLTraining  # uses MLTraining.py helpers

# -----------------------------------------------------------------------------
# Palettes
# -----------------------------------------------------------------------------
# Hex colour per class label, one mapping per annotation depth. Keys are the
# label strings; the mapping for the active depth is selected via PALETTE_BY_DEPTH.

PALETTE_BROAD = {"Immature": "#0079ea", "Mature": "#AF3434"}

PALETTE_SIMPLIFIED = {
    "HSPC":      "#0079ea",
    "Erythroid": "#c11212",
    "pDC":       "#62E6B8",
    "Monocyte":  "#D27CE3",
    "Myeloid":   "#8D43CD",
    "CD4_T":     "#C99546",
    "CD8_T":     "#6B3317",
    "B":         "#68D827",
    "cDC":       "#16D2E3",
    "Other_T":   "#EDB416",
    "NK":        "#FBEF0D",
}

# NOTE(review): "Monocyte" and "CD14 Mono" share the same colour (#D27CE3);
# presumably intentional (generic monocyte drawn like CD14 Mono) — confirm.
PALETTE_DETAILED = {
    "HSC_MPP":            "#0079ea",
    "LMPP":               "#17BECF",
    "GMP":                "#C5E4FF",
    "Myeloid progenitor": "#AEC7E8",
    "Monocyte":           "#D27CE3",
    "CD14 Mono":          "#D27CE3",
    "CD16 Mono":          "#8D43CD",
    "Erythroblast":       "#F30A1A",
    "ErP":                "#D1235A",
    "MEP":                "#E364B0",
    "CD4 T Naive":        "#C99546",
    "CD4 T Memory":       "#C1AF93",
    "CD8 T Naive":        "#4D382E",
    "CD8 T Memory":       "#6B3317",
    "Other_T":            "#EDB416",
    "Treg":               "#6E6C37",
    "B Naive":            "#1C511D",
    "B Memory":           "#68D827",
    "Pro-B":              "#66BB6A",
    "Pre-B":              "#2DBD67",
    "Immature B":         "#91FF7B",
    "Plasma":             "#9DC012",
    "cDC1":               "#76A7CB",
    "cDC2":               "#16D2E3",
    "pDC":                "#69FFCB",
    "NK CD56 bright":     "#F3AC1F",
    "NK CD56 dim":        "#FBEF0D",
}

# Lookup from name_target_class ("Broad" | "Simplified" | "Detailed") to its palette.
PALETTE_BY_DEPTH = {
    "Broad": PALETTE_BROAD,
    "Simplified": PALETTE_SIMPLIFIED,
    "Detailed": PALETTE_DETAILED,
}

# -----------------------------------------------------------------------------
# OPTIONAL: SHAP dependency
# -----------------------------------------------------------------------------
# HAS_SHAP becomes True only if `import shap` succeeds. Any exception during the
# import (not only ImportError) is treated as "SHAP unavailable".
HAS_SHAP = False
try:
    import shap  # noqa: F401
except Exception:
    pass
else:
    HAS_SHAP = True

# -----------------------------------------------------------------------------
# EXPORT SWITCHES
# -----------------------------------------------------------------------------
# RELEASE = curated deliverables; DEV = full diagnostic outputs.
EXPORT_RELEASE = True
EXPORT_DEV     = False

# -----------------------------------------------------------------------------
# CONFIG
# -----------------------------------------------------------------------------
name_target_class = "Simplified"  # "Broad" | "Simplified" | "Detailed"
# Class labels (as strings) to leave out of training/evaluation, e.g. {"Other"}.
# Must be a set: the literal `{}` would be an empty dict, not an empty set.
EXCLUDE_CLASSES = set()

# Colour palette for the selected depth; {} when the depth has no entry.
custom_palette = PALETTE_BY_DEPTH.get(name_target_class, {})
# CV splitter provided by MLTraining; overrides the KFold defined at the top of the notebook.
kf          = MLTraining.CV
# Forwarded as num_cores to the MLTraining trainers; -1 presumably means "all cores"
# (joblib convention) — confirm in MLTraining. Overrides the cpu_count()-2 set earlier.
num_cores   = -1
# Accumulator for per-head metrics records (filled later in the pipeline).
metrics_log = []

# -----------------------------------------------------------------------------
# DIAGNOSTICS / ROBUSTIFICATION SWITCHES (PATCH)
# -----------------------------------------------------------------------------
DEBUG_DIAGNOSTICS = True
HARD_TRACEBACKS   = True   # if True: prints stack traces when plot/SHAP fails
SHAP_TRAIN_SUBSAMPLE_MAX_N = 5000  # set None to disable subsampling
SAFE_SINGLE_THREAD = False  # set True if you see Numba/TBB fork/thread warnings

# -----------------------------------------------------------------------------
# EMBEDDING CONFIG (for Class_Train_data.png)
# -----------------------------------------------------------------------------
# Passed to MLTraining.save_class_train_umap_pngs in step 4.1b. The source selects
# where the 2-D coordinates come from; only the keys matching that source are relevant.
EMBEDDING_SOURCE = "adata_obsm"   # "adata_obsm" | "adata_obs" | "train_df"
EMBEDDING_OBSM_KEY = "X_umap"
EMBEDDING_OBS_X = "UMAP_1"
EMBEDDING_OBS_Y = "UMAP_2"
EMBEDDING_DF_X = "UMAP_1"
EMBEDDING_DF_Y = "UMAP_2"

# -----------------------------------------------------------------------------
# ROOTS
# -----------------------------------------------------------------------------
# models_output is defined earlier in the notebook (not visible here).
Zhang_root = Path(models_output)

# DEV tree.
# NOTE(review): name_target_class appears twice in these paths
# (…/Dev/<depth>/Models/<depth>) — confirm this nesting is intended.
dev_root     = Zhang_root / "Dev"
models_root  = dev_root / name_target_class / "Models"  / name_target_class
reports_root = dev_root / name_target_class / "Reports" / name_target_class
fig_root     = dev_root / name_target_class / "Figures" / name_target_class

heads_dir       = models_root / "heads"
metrics_dir     = reports_root / "metrics"
probs_dir       = reports_root / "probabilities"
fig_percls      = fig_root / "per_class"
dev_importances = reports_root / "Importances"

# RELEASE tree.
release_root    = Zhang_root / "Release"
release_models  = release_root / name_target_class / "Models"
release_reports = release_root / name_target_class / "Reports"
release_metrics = release_reports / "Metrics"
release_probs   = release_reports / "Probabilities"
release_imps    = release_reports / "Importances"
release_figs    = release_root / name_target_class / "Figures"
release_single  = release_figs / "Single_classes"

# Create directories only for the enabled export targets (idempotent).
if EXPORT_DEV:
    for p in (models_root, heads_dir, reports_root, metrics_dir, probs_dir, fig_root, fig_percls, dev_importances):
        p.mkdir(parents=True, exist_ok=True)

if EXPORT_RELEASE:
    for p in (release_models, release_reports, release_metrics, release_probs, release_imps, release_figs, release_single):
        p.mkdir(parents=True, exist_ok=True)
    print(f"[INFO] RELEASE Root:    {release_root}")
    print(f"[INFO] RELEASE Models:  {release_models}")
    print(f"[INFO] RELEASE Reports: {release_reports}")
    print(f"[INFO] RELEASE Figures: {release_figs}")

if DEBUG_DIAGNOSTICS:
    print(f"[DEBUG] HAS_SHAP={HAS_SHAP} EXPORT_RELEASE={EXPORT_RELEASE} EXPORT_DEV={EXPORT_DEV}")
    print(f"[DEBUG] release_single={release_single}")
    print(f"[DEBUG] release_imps={release_imps}")
    print(f"[DEBUG] SAFE_SINGLE_THREAD={SAFE_SINGLE_THREAD} SHAP_SUBSAMPLE_MAX_N={SHAP_TRAIN_SUBSAMPLE_MAX_N}")

# =============================================================================
# SECTION 1: ATTACH CELL-TYPE LABELS
# =============================================================================
print("\n[STEP 1] Attaching cell-type labels from AnnData.obs...")

# obs column holding the consensus label for the selected depth,
# e.g. "Consensus_annotation_simplified_final".
consensus_field = f"Consensus_annotation_{name_target_class.lower()}_final"
# Each split's feature frame receives labels from its matching AnnData object.
# The result is used later as the "Celltype" column — exact join logic lives in
# MLTraining.attach_celltype.
Zhang_data_Train = MLTraining.attach_celltype(Zhang_data_Train, Zhang_dataset_Train, consensus_field)
Zhang_data_Test  = MLTraining.attach_celltype(Zhang_data_Test,  Zhang_dataset_Test,  consensus_field)
Zhang_data_Cal   = MLTraining.attach_celltype(Zhang_data_Cal,   Zhang_dataset_Cal,   consensus_field)

print(f"  ✓ Attached '{consensus_field}' to Train/Test/Cal splits")

# =============================================================================
# SECTION 2: ALIGN DATA COLUMNS TO REFERENCE PANEL
# =============================================================================
print("\n[STEP 2] Aligning data columns to reference panel (exact names preserved)...")

# Reference antibody panel (exact names, as strings).
panel = pd.Index(map(str, TotalSeqD_Heme_Oncology_CAT399906))
# Normalised key -> exact panel name, used to match data columns despite spelling differences.
panel_keys = MLTraining.norm_feats(panel)
norm_to_panel = dict(zip(panel_keys, panel))
# Two panel names that normalise to the same key would make the mapping ambiguous.
if len(norm_to_panel) != len(panel):
    raise ValueError("Panel contains names that collide after normalization. Adjust MLTraining.norm_feats rules.")

def rename_data_to_panel(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df* with feature columns renamed to their exact panel names.

    Feature columns are matched to the panel through ``MLTraining.norm_feats``
    keys (module-level ``norm_to_panel``). Unmatched columns keep their name.
    When several columns map to the same panel name, the first one wins and the
    others are dropped (with a warning). ``cell_barcode`` / ``Celltype`` are never
    treated as features.
    """
    out = df.copy()
    meta_cols = [c for c in ["cell_barcode", "Celltype"] if c in out.columns]
    feature_cols = pd.Index([c for c in out.columns if c not in meta_cols])

    # Candidate renames: original column -> panel name (only for matched columns).
    candidates = {}
    for col, key in zip(feature_cols, MLTraining.norm_feats(feature_cols)):
        target = norm_to_panel.get(key)
        if target is not None:
            candidates[col] = target

    # Keep the first column that claims each panel name; later claimants are dropped.
    claimed = set()
    renames = {}
    duplicates = []
    for col, target in candidates.items():
        if target in claimed:
            duplicates.append(col)
            continue
        claimed.add(target)
        renames[col] = target

    if duplicates:
        print(f"  [WARN] Dropping {len(duplicates)} duplicated-mapped columns (sample: {duplicates[:5]})")
        out.drop(columns=duplicates, inplace=True, errors="ignore")

    out.rename(columns=renames, inplace=True)
    print(f"  ✓ Matched {len(renames)}/{len(feature_cols)} data columns to panel")
    return out

def panel_intersection(df: pd.DataFrame) -> pd.DataFrame:
    """Restrict *df* to panel features (in panel order) followed by its metadata columns.

    Uses the module-level ``panel`` index. Metadata columns are
    ``cell_barcode`` / ``Celltype`` when present.

    Raises
    ------
    ValueError
        If no feature column of *df* is a panel name.
    """
    meta = [c for c in ["cell_barcode", "Celltype"] if c in df.columns]
    features = pd.Index([c for c in df.columns if c not in meta])
    # sort=False keeps the panel's own ordering.
    shared = panel.intersection(features, sort=False)
    if shared.empty:
        raise ValueError("Panel/Data intersection is empty after renaming. Check mapping rules.")
    return df.reindex(columns=[*shared, *meta])

# Rename every split to exact panel names, then keep only panel features
# (panel order) plus metadata columns.
Zhang_data_Train = panel_intersection(rename_data_to_panel(Zhang_data_Train))
Zhang_data_Test  = panel_intersection(rename_data_to_panel(Zhang_data_Test))
Zhang_data_Cal   = panel_intersection(rename_data_to_panel(Zhang_data_Cal))
print("  ✓ Data columns now aligned to panel (panel order preserved)")

# =============================================================================
# SECTION 3: PREPARE FEATURES & LABELS (WITH CAL/TEST ROW FILTERING)
# =============================================================================
print("\n[STEP 3] Extracting features and labels...")

# CAL labels only (index preserved), used for the row filtering below.
Zhang_data_Cal_lbl = Zhang_data_Cal[["Celltype"]].copy()

# Metadata columns to strip from each split, leaving panel features only.
drop_cols_train = [c for c in ["cell_barcode", "Celltype"] if c in Zhang_data_Train.columns]
drop_cols_test  = [c for c in ["cell_barcode", "Celltype"] if c in Zhang_data_Test.columns]
drop_cols_cal   = [c for c in ["cell_barcode", "Celltype"] if c in Zhang_data_Cal.columns]

Zhang_data_Train_Sub = Zhang_data_Train.drop(columns=drop_cols_train, errors="ignore")
Zhang_data_Test_Sub  = Zhang_data_Test.drop(columns=drop_cols_test,  errors="ignore")
Zhang_data_Cal_Sub   = Zhang_data_Cal.drop(columns=drop_cols_cal,    errors="ignore")

# All three splits must expose identical features in identical order.
cols_train = list(Zhang_data_Train_Sub.columns)
if list(Zhang_data_Test_Sub.columns) != cols_train or list(Zhang_data_Cal_Sub.columns) != cols_train:
    raise ValueError("Train/Cal/Test feature columns differ after panel intersection!")

# Finite-value validation per split (exact behaviour in MLTraining.check_finite —
# presumably raises on NaN/Inf).
MLTraining.check_finite(Zhang_data_Train_Sub, "TRAIN")
MLTraining.check_finite(Zhang_data_Test_Sub,  "TEST")
MLTraining.check_finite(Zhang_data_Cal_Sub,   "CAL")

print(f"  ✓ Using {len(cols_train)} panel-intersected features (exact panel names)")
print(f"    Sample: {cols_train[:5]}...")

# classes learned from TRAIN, excluding user-specified
# NOTE(review): class_names filters on str(c), but excluded_present intersects the
# raw values; the two agree only when labels are already strings — confirm.
all_classes = sorted(pd.Series(Zhang_data_Train["Celltype"]).dropna().unique())
class_names = [c for c in all_classes if str(c) not in EXCLUDE_CLASSES]

excluded_present = sorted(set(all_classes).intersection(EXCLUDE_CLASSES))
if excluded_present:
    print(f"  [INFO] Excluding {len(excluded_present)} classes: {excluded_present}")

# Column k of every probability matrix corresponds to class_names[k].
K            = len(class_names)
class_to_idx = {c: i for i, c in enumerate(class_names)}
print(f"  ✓ Found {K} classes after exclusions")

# ---- critical: filter CAL/TEST rows to those classes ----
keep_set = set(map(str, class_names))

cal_keep_mask  = Zhang_data_Cal_lbl["Celltype"].astype(str).isin(keep_set)
test_keep_mask = Zhang_data_Test["Celltype"].astype(str).isin(keep_set)

n_cal_drop  = int((~cal_keep_mask).sum())
n_test_drop = int((~test_keep_mask).sum())

if n_cal_drop > 0:
    dropped = sorted(Zhang_data_Cal_lbl.loc[~cal_keep_mask, "Celltype"].astype(str).unique().tolist())
    print(f"  [INFO] Dropping {n_cal_drop} CAL rows with excluded/unknown labels: {dropped}")

if n_test_drop > 0:
    dropped = sorted(Zhang_data_Test.loc[~test_keep_mask, "Celltype"].astype(str).unique().tolist())
    print(f"  [INFO] Dropping {n_test_drop} TEST rows with excluded/unknown labels: {dropped}")

# filtered label frames
Zhang_data_Cal_lbl_f  = Zhang_data_Cal_lbl.loc[cal_keep_mask].copy()
Zhang_data_Test_lbl_f = Zhang_data_Test.loc[test_keep_mask, ["Celltype"]].copy()

# filtered feature frames (must align by index)
X_cal_all_df = Zhang_data_Cal_Sub.loc[Zhang_data_Cal_lbl_f.index].copy()
X_te_all_df  = Zhang_data_Test_Sub.loc[Zhang_data_Test_lbl_f.index].copy()
# Row order of every TEST probability matrix / export below.
test_index   = X_te_all_df.index

# map filtered labels
s_cal = Zhang_data_Cal_lbl_f["Celltype"].map(class_to_idx)
if s_cal.isna().any():
    missing = Zhang_data_Cal_lbl_f.loc[s_cal.isna(), "Celltype"].unique()
    raise ValueError(f"Still-unmapped labels in CAL after filtering: {missing}")
y_cal_multiclass = s_cal.to_numpy(dtype=np.int64)

s_te = Zhang_data_Test_lbl_f["Celltype"].map(class_to_idx)
if s_te.isna().any():
    missing = Zhang_data_Test_lbl_f.loc[s_te.isna(), "Celltype"].unique()
    raise ValueError(f"Still-unmapped labels in TEST after filtering: {missing}")
y_test_multiclass = s_te.to_numpy(dtype=np.int64)

# probability matrices sized to filtered CAL/TEST
# (columns of skipped heads stay all-zero)
P_cal_raw   = np.zeros((X_cal_all_df.shape[0], K), dtype=float)
P_cal_platt = np.zeros((X_cal_all_df.shape[0], K), dtype=float)

P_te_raw    = np.zeros((X_te_all_df.shape[0],  K), dtype=float)
P_te_platt  = np.zeros((X_te_all_df.shape[0],  K), dtype=float)

# In-memory store of per-class head bundles for the deployable package.
heads_mem = {}

# Row accumulators for the importance / calibration exports.
xgb_shap_rows      = []
lr_contrib_rows    = []
platt_metrics_rows = []

# =============================================================================
# SECTION 4: TRAIN OvR BINARY HEADS (+ Platt on CAL)
# =============================================================================
print(f"\n[STEP 4] Training {K} binary OvR classifiers...\n")

# Number of top SHAP features kept per class in the importance export.
TOP_N = 10
# Base learners, in the same order as the stacker's estimators list.
base_order = ["NB", "XGB", "KNN", "MLP"]
for celltype in class_names:
    k = class_to_idx[celltype]
    cls_safe = MLTraining.safe_name(celltype)
    print(f"▸ Processing {cls_safe} (class {k+1}/{K})")

    # 4.1 Load TRAIN barcodes for this class
    train_barcodes_df = pd.read_csv(
        f"{train_barcodes_path}/Zhang/Consensus_annotation_{name_target_class.lower()}_final/Barcodes_training_class_{cls_safe}.csv",
        index_col=0
    )
    train_positive_barcodes = train_barcodes_df["Positive"].dropna().values
    train_negative_barcodes = train_barcodes_df["Negative"].dropna().values
    all_train_barcodes = np.concatenate([train_positive_barcodes, train_negative_barcodes])

    train_mask = Zhang_data_Train_Sub.index.isin(all_train_barcodes)
    X_tr_df = Zhang_data_Train_Sub.loc[train_mask]
    found_train_barcodes = X_tr_df.index.values
    y_tr = np.isin(found_train_barcodes, train_positive_barcodes).astype(int)

    if X_tr_df.empty or np.unique(y_tr).size < 2:
        print(f"  [SKIP] Empty or single-class train (pos={y_tr.sum()}, neg={len(y_tr)-y_tr.sum()})\n")
        continue

    # 4.1b TRAIN embedding (pos vs rest) + legend
    try:
        MLTraining.save_class_train_umap_pngs(
            celltype=str(celltype),
            cls_safe=cls_safe,
            barcodes=found_train_barcodes,
            y_bin=y_tr,
            custom_palette=custom_palette,
            out_dir_dev=fig_percls if EXPORT_DEV else None,
            out_dir_rel=release_single if EXPORT_RELEASE else None,
            adata_train=Zhang_dataset_Train,
            train_df=Zhang_data_Train,
            embedding_source=EMBEDDING_SOURCE,
            obsm_key=EMBEDDING_OBSM_KEY,
            obs_x=EMBEDDING_OBS_X,
            obs_y=EMBEDDING_OBS_Y,
            df_x=EMBEDDING_DF_X,
            df_y=EMBEDDING_DF_Y,
            neg_color="#A3A3A3",
            outline=(5, 0.05),
            debug=False,
        )
    except Exception as e:
        warnings.warn(f"UMAP train plot failed for '{celltype}': {e}")

    # 4.2 Load TEST barcodes for class-specific metrics (optional)
    test_barcodes_df = pd.read_csv(
        f"{test_barcodes_path}/Zhang/Consensus_annotation_{name_target_class.lower()}_final/Barcodes_testing_class_{cls_safe}.csv",
        index_col=0
    )
    test_positive_barcodes = test_barcodes_df["Positive"].dropna().values
    test_negative_barcodes = test_barcodes_df["Negative"].dropna().values
    all_test_barcodes = np.concatenate([test_positive_barcodes, test_negative_barcodes])

    test_mask = Zhang_data_Test_Sub.index.isin(all_test_barcodes)
    X_te_df = Zhang_data_Test_Sub.loc[test_mask]
    found_test_barcodes = X_te_df.index.values
    y_te = np.isin(found_test_barcodes, test_positive_barcodes).astype(int)

    # Full TEST (filtered) for head probabilities / calibration plot eval
    X_te_all_local = X_te_all_df
    y_te_all = (Zhang_data_Test.loc[X_te_all_df.index, "Celltype"].astype(str).values == str(celltype)).astype(int)

    # CAL split (filtered) for Platt fitting
    X_cal_df  = X_cal_all_df
    y_cal_bin = (Zhang_data_Cal.loc[X_cal_all_df.index, "Celltype"].astype(str).values == str(celltype)).astype(int)

    # 4.3 Fit scaler on TRAIN; transform all splits
    scaler = StandardScaler(with_mean=True, with_std=True).fit(X_tr_df.values)

    def _sc(df: pd.DataFrame) -> pd.DataFrame:
        return pd.DataFrame(scaler.transform(df.values), index=df.index, columns=cols_train)

    X_tr_sc_df      = _sc(X_tr_df)
    X_te_sc_df      = _sc(X_te_df)
    X_te_all_sc_df  = _sc(X_te_all_local)
    X_cal_sc_df     = _sc(X_cal_df)

    # 4.4 Train base learners
    NB_model  = MLTraining.train_NB (X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)
    XGB_model = MLTraining.train_XGB(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)
    KNN_model = MLTraining.train_KNN(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)
    MLP_model = MLTraining.train_MLP(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)

    # 4.5 Stacking RAW head
    stacker_raw = StackingClassifier(
        estimators=[("NB", NB_model), ("XGB", XGB_model), ("KNN", KNN_model), ("MLP", MLP_model)],
        final_estimator=LogisticRegression(max_iter=2000, class_weight="balanced", random_state=42),
        stack_method="predict_proba",
        cv=kf,
        n_jobs=-1,
    ).fit(X_tr_sc_df, y_tr)

    # 4.6 Platt calibration (fit on CAL only)
    pos_cal   = int(y_cal_bin.sum())
    n_cal_bin = int(len(y_cal_bin))
    has_both  = (0 < pos_cal < n_cal_bin)

    stacker_platt = None
    if has_both:
        stacker_platt = MLTraining.calibrate_prefit(stacker_raw, X_cal_sc_df, y_cal_bin, method="sigmoid")
    else:
        print("    [WARN] Skipped Platt calibration (single-class CAL)")

    # 4.7 Platt evaluation curve on TEST (Ideal -> RAW -> Platt) + metrics row
    try:
        p_test_raw   = stacker_raw.predict_proba(X_te_all_sc_df)[:, 1]
        p_test_platt = stacker_platt.predict_proba(X_te_all_sc_df)[:, 1] if stacker_platt is not None else None

        dev_platt = (fig_percls / f"{cls_safe}_Platt_calibration_evaluation_on_test.png") if EXPORT_DEV else None
        rel_platt = (release_single / f"{cls_safe}_Platt_calibration_evaluation_on_test.png") if EXPORT_RELEASE else None

        ll_raw, br_raw, ll_pl, br_pl, pl_avail = MLTraining.plot_platt_calibration_on_test(
            y_true_bin=y_te_all.astype(int),
            p_raw=p_test_raw,
            p_platt=p_test_platt,
            title=f"{name_target_class} – {celltype}: Platt calibration evaluation on TEST",
            out_png_dev=dev_platt,
            out_png_rel=rel_platt,
            n_bins=15,
        )

        platt_metrics_rows.append({
            "depth": name_target_class,
            "class_name": str(celltype),
            "n_test_samples": int(len(y_te_all)),
            "n_test_positive": int(y_te_all.sum()),
            "logloss_raw": ll_raw,
            "brier_raw": br_raw,
            "logloss_platt": ll_pl,
            "brier_platt": br_pl,
            "platt_available": bool(pl_avail),
        })

    except Exception as e:
        warnings.warn(f"Platt calibration plot failed for class '{celltype}': {e}")

    # 4.8 Save per-class head bundle + keep in-memory for package
    head_bundle = {
        "atlas": "Zhang",
        "depth": name_target_class,
        "label": str(celltype),
        "panel_name": "TotalSeqD_Heme_Oncology_CAT399906",
        "columns": cols_train,
        "scaler": scaler,
        "model_raw": stacker_raw,
        "model_platt": stacker_platt,
    }
    heads_mem[str(celltype)] = head_bundle

    if EXPORT_DEV:
        joblib.dump(head_bundle, heads_dir / f"{cls_safe}.joblib")

    # 4.9 Optional per-head metrics logging (class-specific TEST subset)
    try:
        model_for_eval = stacker_platt if stacker_platt is not None else stacker_raw
        m = MLTraining.evaluate_classifier(model_for_eval, X_te_sc_df, y_te, plot_cm=False)
        m.update(celltype=str(celltype), used_platt=bool(stacker_platt is not None))
        metrics_log.append(m)
    except Exception:
        pass

    # 4.10 OvR probability matrices (RAW + PLATT) for multiclass downstream
    P_cal_raw[:, k] = stacker_raw.predict_proba(X_cal_sc_df)[:, 1]
    P_te_raw[:,  k] = stacker_raw.predict_proba(X_te_all_sc_df)[:, 1]

    if stacker_platt is not None:
        P_cal_platt[:, k] = stacker_platt.predict_proba(X_cal_sc_df)[:, 1]
        P_te_platt[:,  k] = stacker_platt.predict_proba(X_te_all_sc_df)[:, 1]
    else:
        P_cal_platt[:, k] = P_cal_raw[:, k]
        P_te_platt[:,  k] = P_te_raw[:,  k]

    # 4.11 SHAP: mean_abs + corr on TEST; beeswarm TRAIN only
    if HAS_SHAP:
        try:
            plt.figure(figsize=(6, 6))
            shap_sum_test = MLTraining.xgb_shap_mean_abs_and_corr(XGB_model, X_te_all_sc_df, class_index=1)
            shap_sum_test["depth"] = name_target_class
            shap_sum_test["class_name"] = str(celltype)
            shap_sum_test["dataset"] = "TEST"
            xgb_shap_rows.extend(shap_sum_test.to_dict(orient="records"))

            # Beeswarm on TRAIN only
            if EXPORT_DEV:
                outp = fig_percls / f"{cls_safe}_SHAP_beeswarm_TRAIN.png"
                sv, _ = MLTraining.xgb_shap_values(XGB_model, X_tr_sc_df, class_index=1)
                MLTraining.plot_xgb_shap_beeswarm(
                    sv, X_tr_sc_df,
                    title=f"{name_target_class} – {celltype}: SHAP importance",
                    max_display=TOP_N,
                    out_path=outp,
                    figsize=(6, 7),
                )
            if EXPORT_RELEASE:
                outp = release_single / f"{cls_safe}_SHAP_beeswarm_TRAIN.png"
                sv, _ = MLTraining.xgb_shap_values(XGB_model, X_tr_sc_df, class_index=1)
                MLTraining.plot_xgb_shap_beeswarm(
                    sv, X_tr_sc_df,
                    title=f"{name_target_class} – {celltype}: SHAP importance",
                    max_display=TOP_N,
                    out_path=outp,
                    figsize=(6, 7),
                )

        except Exception as e:
            warnings.warn(f"SHAP failed for class '{celltype}': {e}")

    # 4.12 LR meta-learner contributions (unchanged)
    try:
        contrib = _lr_baselearner_contributions(stacker_raw, X_te_all_sc_df, base_order=base_order)
        row = {
            "depth": name_target_class,
            "class_name": str(celltype),
            "dataset": "TEST",
            "n_meta_features": contrib["n_meta_features"],
            "per_estimator_meta_cols": contrib["per_estimator_meta_cols"],
        }
        for b in base_order:
            row[f"{b}_mean_abs_contribution"] = contrib["per_base"].get(b, {}).get("mean_abs_contribution", 0.0)
            row[f"{b}_coef_l1"]               = contrib["per_base"].get(b, {}).get("coef_l1", 0.0)
            row[f"{b}_n_meta_cols"]           = contrib["per_base"].get(b, {}).get("n_cols", 0)
        lr_contrib_rows.append(row)
    except Exception as e:
        warnings.warn(f"LR contribution extraction failed for class '{celltype}': {e}")

    print("")

# =============================================================================
# EXPORT: Per-class LogLoss & Brier (pre vs post Platt) on TEST
# =============================================================================
# Writes the rows accumulated in `platt_metrics_rows` by the per-class loop.
# A destination is passed as None when its export switch is off (presumably the
# helper then skips that target — confirm in MLTraining.export_platt_metrics_csv).
print("\n[EXPORT] Per-class calibration metrics (RAW vs Platt on TEST)...")

_platt_dev_dir = None
if EXPORT_DEV:
    _platt_dev_dir = metrics_dir

_platt_rel_dir = None
if EXPORT_RELEASE:
    _platt_rel_dir = release_metrics

_ = MLTraining.export_platt_metrics_csv(
    platt_metrics_rows,
    filename="Single_classes_metrics_pre_and_post_platt_calibration.csv",
    out_dev=_platt_dev_dir,
    out_rel=_platt_rel_dir,
)

# =============================================================================
# SECTION 5: MULTICLASS TEMPERATURE SCALING (fit on CAL using PLATT matrix)
# =============================================================================
print("\n[STEP 5] Multiclass Temperature Scaling on CAL (using Platt OvR probabilities)...")

def _check_probs(P: np.ndarray, name: str):
    if np.isnan(P).any() or np.isinf(P).any():
        raise ValueError(f"{name} contains NaN/Inf")
    if (P < 0).any() or (P > 1).any():
        raise ValueError(f"{name} contains values outside [0,1]")

_check_probs(P_cal_platt, "P_cal_platt")
_check_probs(P_te_platt,  "P_te_platt")

# Fit one temperature on the CAL split's Platt OvR matrix, then apply it to TEST.
ts_cal = TemperatureScaling()
ts_cal.fit(P_cal_platt, y_cal_multiclass)
P_te_cal = np.asarray(ts_cal.transform(P_te_platt))
if P_te_cal.ndim == 1:
    # A 1-D result is treated as a single probability column.
    P_te_cal = P_te_cal.reshape(-1, 1)

if P_te_cal.shape[1] == 1 and K == 2:
    # Binary: the single column is taken as P(class 1); prepend its complement.
    P_te_cal = np.hstack([1.0 - P_te_cal, P_te_cal])
elif P_te_cal.shape[1] != K:
    # Capture the scaler's output shape BEFORE P_te_cal is overwritten, so the
    # warning reports what TemperatureScaling actually returned.
    returned_shape = P_te_cal.shape
    # Fallback: row-normalise the Platt OvR matrix (all-zero rows left as zeros).
    row_sums = P_te_platt.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0.0] = 1.0
    P_te_cal = P_te_platt / row_sums
    print(f"  [WARN] TemperatureScaling returned shape {returned_shape}; fell back to sum-normalized OvR probs")
    # NOTE(review): `ts_cal` is still persisted/packaged below even in this
    # fallback branch — downstream users of it will not get the fallback.

if EXPORT_DEV:
    joblib.dump(ts_cal, models_root / "temp_scaler.joblib")
    pd.Series(class_names, name="class_name").to_csv(models_root / "class_names.csv", index=False)

# =============================================================================
# SECTION 5b: SAVE DEPLOYABLE PACKAGE(S)
# =============================================================================
print("\n[STEP 5b] Saving deployable package(s)...")

# Single bundle for this depth: the per-class heads collected in `heads_mem`
# plus the multiclass temperature scaler fitted in Section 5.
package = dict(
    atlas="Zhang",
    depth=name_target_class,
    panel_name="TotalSeqD_Heme_Oncology_CAT399906",
    class_names=class_names,
    heads=heads_mem,
    temp_scaler=ts_cal,
)

# Same object is dumped to every enabled target (DEV first, then RELEASE).
_package_targets = []
if EXPORT_DEV:
    _package_targets.append(models_root / "package.joblib")
if EXPORT_RELEASE:
    _package_targets.append(release_models / "Multiclass_models.joblib")
for _package_path in _package_targets:
    joblib.dump(package, _package_path)

# =============================================================================
# SECTION 5c: EXPORT IMPORTANCES (Top10 per class)
# =============================================================================
print("\n[STEP 5c] Exporting importances (Top 10 per class; SHAP mean_abs + corr + LR)...")

_SHAP_GROUP = ["depth", "class_name"]
_SHAP_EXPORT_COLS = [
    "depth", "class_name", "dataset", "feature",
    "mean_abs_shap", "corr_feature_value_vs_shap", "rank_within_class",
]

if len(xgb_shap_rows) == 0:
    print("  [INFO] No SHAP rows collected (or SHAP not installed).")
else:
    # Per (depth, class): keep the TOP_N features with the largest mean |SHAP|.
    shap_df = (
        pd.DataFrame(xgb_shap_rows)
        .sort_values(_SHAP_GROUP + ["mean_abs_shap"], ascending=[True, True, False])
        .groupby(_SHAP_GROUP, as_index=False)
        .head(TOP_N)
    )
    # 1 = most important; ties are ranked in order of appearance (method="first").
    _shap_ranks = shap_df.groupby(_SHAP_GROUP)["mean_abs_shap"].rank(ascending=False, method="first")
    shap_df["rank_within_class"] = _shap_ranks.astype(int)
    shap_df = shap_df[_SHAP_EXPORT_COLS]

    if EXPORT_DEV:
        shap_df.to_csv(dev_importances / "SHAP_XGB_Feature_importances.csv", index=False)
    if EXPORT_RELEASE:
        shap_df.to_csv(release_imps / "SHAP_XGB_Feature_importances.csv", index=False)

if len(lr_contrib_rows) == 0:
    print("  [INFO] No LR contribution rows collected.")
else:
    # One row per class with the stacking LR's per-base-learner contributions.
    lr_df = pd.DataFrame(lr_contrib_rows)
    if EXPORT_DEV:
        lr_df.to_csv(dev_importances / "LR_MetaLearner_BaseLearner_contributions.csv", index=False)
    if EXPORT_RELEASE:
        lr_df.to_csv(release_imps / "LR_MetaLearner_BaseLearner_contributions.csv", index=False)

# =============================================================================
# SECTION 6: SAVE PROBABILITIES
# =============================================================================
print("\n[STEP 6] Saving probability outputs...")

if EXPORT_DEV:
    # DEV: RAW, Platt and temperature-scaled matrices side by side (prefixed
    # column names), followed by the true label and both argmax predictions.
    _prob_frames = [
        pd.DataFrame(_matrix, index=test_index, columns=[f"{_prefix}_{c}" for c in class_names])
        for _prefix, _matrix in (("raw", P_te_raw), ("platt", P_te_platt), ("cal", P_te_cal))
    ]
    probs_raw_df, probs_platt_df, probs_cal_df = _prob_frames

    probs_dev = pd.concat(_prob_frames, axis=1)
    probs_dev["true_label"] = Zhang_data_Test.loc[test_index, "Celltype"].values
    _dev_pred_raw = P_te_raw.argmax(axis=1)
    _dev_pred_cal = P_te_cal.argmax(axis=1)
    probs_dev["pred_raw"] = _dev_pred_raw
    probs_dev["pred_cal"] = _dev_pred_cal
    probs_dev["pred_raw_name"] = [class_names[i] for i in _dev_pred_raw]
    probs_dev["pred_cal_name"] = [class_names[i] for i in _dev_pred_cal]
    probs_dev.to_csv(probs_dir / "probabilities_before_after_TEST.csv", index=True)

if EXPORT_RELEASE:
    # RELEASE: calibrated matrix only, its argmax, and the winning probability.
    probs_cal_df = pd.DataFrame(P_te_cal, index=test_index, columns=[f"cal_{c}" for c in class_names])
    _rel_pred_cal = P_te_cal.argmax(axis=1)

    probs_release = probs_cal_df.copy()
    probs_release["true_label"] = Zhang_data_Test.loc[test_index, "Celltype"].values
    probs_release["pred_cal"] = _rel_pred_cal
    probs_release["pred_cal_name"] = [class_names[i] for i in _rel_pred_cal]
    # Max over the probability columns only (not the label/prediction columns).
    probs_release["max_cal_prob"] = probs_cal_df.max(axis=1).values
    probs_release.to_csv(release_probs / "Multiclass_models_probabilities_on_test.csv", index=True)

# =============================================================================
# SECTION 7: MULTICLASS EVALUATION (TEST) — using CAL probabilities
# =============================================================================
print("\n[STEP 7] Multiclass evaluation (TEST; using CAL probs)...\n")

y_pred_cal = P_te_cal.argmax(axis=1)

# Pin the label set to 0..K-1 (the columns of P_te_cal) so `target_names`
# always lines up with `class_names`. Without `labels=`, sklearn infers labels
# from y_true/y_pred and raises ValueError when a class is absent from both.
# Using the same labels for the confusion matrix keeps the two consistent.
_mc_labels = list(range(K))

report_txt = classification_report(
    y_test_multiclass, y_pred_cal, labels=_mc_labels, target_names=class_names, digits=3
)
print("Multiclass Classification Report (TEST):")
print(report_txt)

cm_mc = confusion_matrix(y_test_multiclass, y_pred_cal, labels=_mc_labels)
print("\nConfusion Matrix (rows=true, cols=pred):")
print(cm_mc)

report_df = pd.DataFrame(
    classification_report(
        y_test_multiclass, y_pred_cal, labels=_mc_labels, target_names=class_names, output_dict=True
    )
).T

cm_mc_df = pd.DataFrame(cm_mc, index=pd.Index(class_names, name="true"), columns=pd.Index(class_names, name="pred"))

if EXPORT_DEV:
    report_df.to_csv(metrics_dir / "multiclass_classification_report_TEST.csv")
    cm_mc_df.to_csv(metrics_dir / "multiclass_confusion_matrix_TEST.csv")

if EXPORT_RELEASE:
    report_df.to_csv(release_metrics / "Multiclass_models_metrics_on_test.csv")
    cm_mc_df.to_csv(release_metrics / "Multiclass_models_confusion_matrix_on_test.csv")

# =============================================================================
# SECTION 8: FIGURES (MULTICLASS CM + PER-CLASS CONF & ROC)
# =============================================================================
print("\n[STEP 8] Saving plots...")

def _save_multiclass_cm_png(out_path: Path):
    """Render the module-level multiclass confusion matrix and save it as a PNG.

    Reads the globals ``cm_mc`` (K x K counts), ``class_names`` (tick labels)
    and ``name_target_class`` (title prefix).

    The display is drawn onto axes created here. ``ConfusionMatrixDisplay.plot``
    opens a *new* default-sized figure when it is not given ``ax``. Without the
    explicit ``ax``, the ``figsize`` below would be ignored, the image would be
    saved from that implicit figure, and only the empty explicit figure would be
    closed, leaking one figure per call.

    Args:
        out_path: destination PNG path; written at 300 dpi.
    """
    fig, ax = plt.subplots(figsize=(7, 6))
    try:
        disp = ConfusionMatrixDisplay(confusion_matrix=cm_mc, display_labels=class_names)
        disp.plot(ax=ax, values_format="d", cmap="Blues", colorbar=False)
        ax.set_title(f"{name_target_class} – Multiclass Confusion Matrix (on TEST)")
        fig.tight_layout()
        fig.savefig(out_path, dpi=300)
    finally:
        # Always release the figure, even if plotting or saving raised.
        plt.close(fig)

# Multiclass confusion-matrix figure, written once per enabled export target.
_cm_png_targets = []
if EXPORT_DEV:
    _cm_png_targets.append(fig_root / "multiclass_confusion_matrix_TEST.png")
if EXPORT_RELEASE:
    _cm_png_targets.append(release_figs / "Multiclass_models_confusion_matrix_on_test.png")
for _cm_png in _cm_png_targets:
    _save_multiclass_cm_png(_cm_png)

# Accumulator for the per-class rows, plus the RAW argmax labels that the
# one-vs-rest loop compares against the calibrated ones.
per_class_rows = []
y_pred_raw = P_te_raw.argmax(axis=1)

def _metrics_from_cm(cm2x2):
    tn, fp, fn, tp = cm2x2.ravel()
    support = tp + fn
    prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    rec  = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1   = (2 * prec * rec) / (prec + rec) if (prec + rec) > 0 else 0.0
    return dict(TP=int(tp), FP=int(fp), TN=int(tn), FN=int(fn),
                support=int(support), precision=prec, recall=rec, f1=f1)

def _save_cm_fig(cm2x2, cls_label, title, out_dev: Path | None, out_rel: Path | None):
    """Plot a one-vs-rest 2x2 confusion matrix and save it to each given path.

    The display is drawn onto axes created here. ``ConfusionMatrixDisplay.plot``
    creates a new default-sized figure when no ``ax`` is passed. Without the
    explicit ``ax``, the requested ``figsize`` would be ignored and that implicit
    figure would never be closed, leaking one figure per call (this function
    runs twice per class).

    Args:
        cm2x2: 2x2 confusion matrix with the negative ("Other") class first.
        cls_label: tick label for the positive class.
        title: axes title.
        out_dev: DEV destination PNG, or None to skip.
        out_rel: RELEASE destination PNG, or None to skip.
    """
    fig, ax = plt.subplots(figsize=(5.5, 5.0))
    try:
        ConfusionMatrixDisplay(confusion_matrix=cm2x2, display_labels=["Other", cls_label]).plot(
            ax=ax, values_format="d", cmap="Blues", colorbar=False
        )
        ax.set_title(title)
        fig.tight_layout()
        for dest in (out_dev, out_rel):
            if dest is not None:
                fig.savefig(dest, dpi=300)
    finally:
        # Always release the figure, even if plotting or saving raised.
        plt.close(fig)

def _save_roc(y_true, y_score, title, out_dev: Path | None, out_rel: Path | None):
    """Plot a binary ROC curve, save it to each non-None path, and return its AUC.

    Args:
        y_true: binary ground-truth labels.
        y_score: scores for the positive class.
        title: title prefix; the AUC (3 decimals) is appended.
        out_dev: DEV destination PNG, or None to skip.
        out_rel: RELEASE destination PNG, or None to skip.

    Returns:
        Area under the ROC curve from ``sklearn.metrics.auc``.
    """
    false_pos_rate, true_pos_rate, _thresholds = roc_curve(y_true, y_score)
    area = auc(false_pos_rate, true_pos_rate)

    fig = plt.figure(figsize=(6.0, 5.5))
    # Chance diagonal first, then the model's curve drawn over it.
    plt.plot([0, 1], [0, 1], linestyle="--", linewidth=1, color="gray")
    plt.plot(false_pos_rate, true_pos_rate)
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"{title} AUC={area:.3f}")
    plt.tight_layout()
    for destination in (out_dev, out_rel):
        if destination is not None:
            plt.savefig(destination, dpi=300)
    plt.close(fig)
    return area

def _per_class_fig_paths(filename):
    """Return the (DEV, RELEASE) destinations for one per-class figure.

    Each element is None when its export switch is off.
    """
    dev_path = (fig_percls / filename) if EXPORT_DEV else None
    rel_path = (release_single / filename) if EXPORT_RELEASE else None
    return dev_path, rel_path


# One-vs-rest view of the multiclass argmax. For every class: binarise truth and
# predictions, save RAW and CAL confusion matrices and ROC curves, and record one
# RAW and one CAL metrics row.
for k, cls in enumerate(class_names):
    cls_safe = MLTraining.safe_name(cls)
    y_true_bin = (y_test_multiclass == k).astype(int)

    # Column k of each probability matrix scores class k for the ROC curves.
    score_raw = P_te_raw[:, k]
    score_cal = P_te_cal[:, k]

    # "Predicted positive" means class k won the argmax.
    y_pred_raw_bin = (y_pred_raw == k).astype(int)
    y_pred_cal_bin = (y_pred_cal == k).astype(int)
    cm_raw = confusion_matrix(y_true_bin, y_pred_raw_bin, labels=[0, 1])
    cm_cal = confusion_matrix(y_true_bin, y_pred_cal_bin, labels=[0, 1])

    title_prefix = f"{name_target_class} – {cls}"

    dev_out, rel_out = _per_class_fig_paths(f"{cls_safe}_RAW_confusion_matrix_on_test.png")
    _save_cm_fig(cm_raw, cls, f"{title_prefix}: Confusion Matrix (RAW; pre-Platt & Temp)", dev_out, rel_out)

    dev_out, rel_out = _per_class_fig_paths(f"{cls_safe}_CAL_confusion_matrix_on_test.png")
    _save_cm_fig(cm_cal, cls, f"{title_prefix}: Confusion Matrix (CAL; Platt & Temp)", dev_out, rel_out)

    dev_out, rel_out = _per_class_fig_paths(f"{cls_safe}_RAW_ROC_on_test.png")
    auc_raw = _save_roc(y_true_bin, score_raw, f"{title_prefix}: ROC (RAW; pre-Platt & Temp)", dev_out, rel_out)

    dev_out, rel_out = _per_class_fig_paths(f"{cls_safe}_CAL_ROC_on_test.png")
    auc_cal = _save_roc(y_true_bin, score_cal, f"{title_prefix}: ROC (CAL; Platt & Temp)", dev_out, rel_out)

    m_raw = _metrics_from_cm(cm_raw)
    m_raw.update(model="RAW", class_name=cls, auc=auc_raw)
    m_cal = _metrics_from_cm(cm_cal)
    m_cal.update(model="CAL", class_name=cls, auc=auc_cal)
    per_class_rows.extend([m_raw, m_cal])

# =============================================================================
# SECTION 9: SAVE METRICS TABLES
# =============================================================================
print("\n[STEP 9] Saving metrics tables...")

_PER_CLASS_COLS = ["class_name", "model", "TP", "FP", "TN", "FN", "support", "precision", "recall", "f1", "auc"]
# One row per (class, model) with model in {"RAW", "CAL"}, sorted for readability.
per_class_df = pd.DataFrame(per_class_rows)[_PER_CLASS_COLS].sort_values(["class_name", "model"])

if EXPORT_DEV:
    per_class_df.to_csv(metrics_dir / "per_class_argmax_metrics_TEST_included.csv", index=False)

if EXPORT_RELEASE:
    out_single = release_metrics / "Single_classes_metrics_and_confusion_matrix_on_test.csv"
    per_class_df.to_csv(out_single, index=False)

if EXPORT_DEV:
    # Per-head evaluation dicts gathered in `metrics_log` during training.
    metrics_df = pd.DataFrame.from_records(metrics_log)
    MLTraining.append_metrics_csv(metrics_df, csv_path=dev_root / "stacker_metrics.csv")

print("\n✅ SIMPLIFIED PIPELINE COMPLETE. Exports saved according to EXPORT_DEV / EXPORT_RELEASE.\n")


#### Detailed annotation

In [None]:
# -*- coding: utf-8 -*-
# =============================================================================
# MODEL TRAINING PIPELINE (LEAN MAIN SCRIPT)
#   - RAW vs PLATT vs TEMP-SCALED
#   - DEV/RELEASE exports
#   - Importances: XGB SHAP mean_abs + corr (Top10) + LR meta-learner contributions
#   - Platt calibration plots (Ideal -> RAW -> Platt on top) with TEST LogLoss/Brier in legend
#   - Per-class pre/post Platt metrics exported to CSV
#   - Per-class TRAIN UMAP (pos vs rest) + legend PNG
#
# PATCHES ADDED (to address “plots missing / skipped” symptoms):
#   (A) Optional DEBUG_DIAGNOSTICS: prints output paths + CAL class balance + confirms file writes.
#   (B) Hard traceback on failures (instead of silent warnings) to surface root cause.
#   (C) SHAP beeswarm robustification: optional subsample of TRAIN to avoid memory/time failures.
#   (D) Optional SAFE_SINGLE_THREAD: mitigates fork/thread/numba/TBB instability during SHAP/plotting.
#   (E) Explicit existence checks after savefig (so “saved but not where expected” is obvious).
# =============================================================================

# =============================================================================
# SECTION 0: IMPORTS + CONFIG
# =============================================================================

from pathlib import Path
import joblib
import warnings
import traceback
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import StackingClassifier
from sklearn.metrics import (
    classification_report, confusion_matrix, ConfusionMatrixDisplay,
    roc_curve, auc,
)

import MLTraining  # uses MLTraining.py helpers

# -----------------------------------------------------------------------------
# Palettes
# -----------------------------------------------------------------------------
# Class-label -> hex colour maps, one per annotation depth. PALETTE_BY_DEPTH is
# indexed by `name_target_class` to pick `custom_palette` for the current run.

# Depth "Broad": two-way split.
PALETTE_BROAD = {"Immature": "#0079ea", "Mature": "#AF3434"}

# Depth "Simplified".
PALETTE_SIMPLIFIED = {
    "HSPC":      "#0079ea",
    "Erythroid": "#c11212",
    "pDC":       "#62E6B8",
    "Monocyte":  "#D27CE3",
    "Myeloid":   "#8D43CD",
    "CD4_T":     "#C99546",
    "CD8_T":     "#6B3317",
    "B":         "#68D827",
    "cDC":       "#16D2E3",
    "Other_T":   "#EDB416",
    "NK":        "#FBEF0D",
}

# Depth "Detailed". Note: "Monocyte" and "CD14 Mono" share the same colour.
PALETTE_DETAILED = {
    "HSC_MPP":            "#0079ea",
    "LMPP":               "#17BECF",
    "GMP":                "#C5E4FF",
    "Myeloid progenitor": "#AEC7E8",
    "Monocyte":           "#D27CE3",
    "CD14 Mono":          "#D27CE3",
    "CD16 Mono":          "#8D43CD",
    "Erythroblast":       "#F30A1A",
    "ErP":                "#D1235A",
    "MEP":                "#E364B0",
    "CD4 T Naive":        "#C99546",
    "CD4 T Memory":       "#C1AF93",
    "CD8 T Naive":        "#4D382E",
    "CD8 T Memory":       "#6B3317",
    "Other_T":            "#EDB416",
    "Treg":               "#6E6C37",
    "B Naive":            "#1C511D",
    "B Memory":           "#68D827",
    "Pro-B":              "#66BB6A",
    "Pre-B":              "#2DBD67",
    "Immature B":         "#91FF7B",
    "Plasma":             "#9DC012",
    "cDC1":               "#76A7CB",
    "cDC2":               "#16D2E3",
    "pDC":                "#69FFCB",
    "NK CD56 bright":     "#F3AC1F",
    "NK CD56 dim":        "#FBEF0D",
}

# Lookup used with `name_target_class` ("Broad" | "Simplified" | "Detailed").
PALETTE_BY_DEPTH = {
    "Broad": PALETTE_BROAD,
    "Simplified": PALETTE_SIMPLIFIED,
    "Detailed": PALETTE_DETAILED,
}

# -----------------------------------------------------------------------------
# OPTIONAL: SHAP dependency
# -----------------------------------------------------------------------------
# Any failure while importing shap (not just ImportError) disables the SHAP
# steps instead of aborting the pipeline.
HAS_SHAP = False
try:
    import shap  # noqa: F401
except Exception:
    pass
else:
    HAS_SHAP = True

# -----------------------------------------------------------------------------
# EXPORT SWITCHES
# -----------------------------------------------------------------------------
# Each switch gates the writes to its own output tree under `models_output`
# (Release/... and Dev/... respectively).
EXPORT_RELEASE, EXPORT_DEV = True, False

# -----------------------------------------------------------------------------
# CONFIG
# -----------------------------------------------------------------------------
name_target_class = "Detailed"  # "Broad" | "Simplified" | "Detailed"
# Class labels to leave out. This must be an empty *set*: `{}` is an empty dict,
# which only mimics a set for membership tests and raises TypeError with set
# operators (`|`, `&`, `-`). A populated literal such as {"Plasma"} is already a set.
EXCLUDE_CLASSES = set()

# Colour map for the selected depth; empty dict when the depth has no palette.
if name_target_class in PALETTE_BY_DEPTH:
    custom_palette = PALETTE_BY_DEPTH[name_target_class]
else:
    custom_palette = {}
# Cross-validation splitter provided by MLTraining.
kf = MLTraining.CV
# Passed through to the trainers; -1 presumably means "all cores" — confirm in MLTraining.
num_cores = -1
# Accumulator for per-head evaluation metric dicts.
metrics_log = []

# -----------------------------------------------------------------------------
# DIAGNOSTICS / ROBUSTIFICATION SWITCHES (PATCH)
# -----------------------------------------------------------------------------
DEBUG_DIAGNOSTICS = True            # print resolved output paths and switch values
HARD_TRACEBACKS = True              # if True: print stack traces when plot/SHAP fails
SHAP_TRAIN_SUBSAMPLE_MAX_N = 5000   # cap on TRAIN rows for the SHAP beeswarm; None disables
SAFE_SINGLE_THREAD = False          # set True if you see Numba/TBB fork/thread warnings

# -----------------------------------------------------------------------------
# EMBEDDING CONFIG (for Class_Train_data.png)
# -----------------------------------------------------------------------------
EMBEDDING_SOURCE = "adata_obsm"     # "adata_obsm" | "adata_obs" | "train_df"
EMBEDDING_OBSM_KEY = "X_umap"       # obsm key used with "adata_obsm"
# Coordinate column names used with "adata_obs" and "train_df" respectively.
EMBEDDING_OBS_X, EMBEDDING_OBS_Y = "UMAP_1", "UMAP_2"
EMBEDDING_DF_X, EMBEDDING_DF_Y = "UMAP_1", "UMAP_2"

# -----------------------------------------------------------------------------
# ROOTS
# -----------------------------------------------------------------------------
# Output layout under `models_output`. Only the trees whose export switch is on
# are created:
#   Dev/<depth>/{Models,Reports,Figures}/<depth>/...
#   Release/<depth>/{Models,Reports,Figures}/...
# NOTE(review): the DEV tree repeats <depth> one level down while RELEASE does
# not — confirm this is intended.
Zhang_root = Path(models_output)

# DEV tree
dev_root = Zhang_root / "Dev"
models_root = dev_root / name_target_class / "Models" / name_target_class
reports_root = dev_root / name_target_class / "Reports" / name_target_class
fig_root = dev_root / name_target_class / "Figures" / name_target_class

heads_dir = models_root / "heads"
fig_percls = fig_root / "per_class"
metrics_dir, probs_dir, dev_importances = (
    reports_root / "metrics",
    reports_root / "probabilities",
    reports_root / "Importances",
)

# RELEASE tree
release_root = Zhang_root / "Release"
release_models, release_reports, release_figs = (
    release_root / name_target_class / sub for sub in ("Models", "Reports", "Figures")
)
release_metrics = release_reports / "Metrics"
release_probs = release_reports / "Probabilities"
release_imps = release_reports / "Importances"
release_single = release_figs / "Single_classes"

if EXPORT_DEV:
    _dev_dirs = (models_root, heads_dir, reports_root, metrics_dir, probs_dir, fig_root, fig_percls, dev_importances)
    for p in _dev_dirs:
        p.mkdir(parents=True, exist_ok=True)

if EXPORT_RELEASE:
    _release_dirs = (release_models, release_reports, release_metrics, release_probs, release_imps, release_figs, release_single)
    for p in _release_dirs:
        p.mkdir(parents=True, exist_ok=True)
    print(f"[INFO] RELEASE Root:    {release_root}")
    print(f"[INFO] RELEASE Models:  {release_models}")
    print(f"[INFO] RELEASE Reports: {release_reports}")
    print(f"[INFO] RELEASE Figures: {release_figs}")

if DEBUG_DIAGNOSTICS:
    print(f"[DEBUG] HAS_SHAP={HAS_SHAP} EXPORT_RELEASE={EXPORT_RELEASE} EXPORT_DEV={EXPORT_DEV}")
    print(f"[DEBUG] release_single={release_single}")
    print(f"[DEBUG] release_imps={release_imps}")
    print(f"[DEBUG] SAFE_SINGLE_THREAD={SAFE_SINGLE_THREAD} SHAP_SUBSAMPLE_MAX_N={SHAP_TRAIN_SUBSAMPLE_MAX_N}")

# =============================================================================
# SECTION 1: ATTACH CELL-TYPE LABELS
# =============================================================================
print("\n[STEP 1] Attaching cell-type labels from AnnData.obs...")

# Consensus-annotation field for the selected depth,
# e.g. "Consensus_annotation_detailed_final" when name_target_class == "Detailed".
consensus_field = "Consensus_annotation_{}_final".format(name_target_class.lower())

# Each feature table is paired with its Zhang_dataset_* split; the helper returns
# the table with the label attached.
Zhang_data_Train, Zhang_data_Test, Zhang_data_Cal = (
    MLTraining.attach_celltype(table, source, consensus_field)
    for table, source in (
        (Zhang_data_Train, Zhang_dataset_Train),
        (Zhang_data_Test, Zhang_dataset_Test),
        (Zhang_data_Cal, Zhang_dataset_Cal),
    )
)

print(f"  ✓ Attached '{consensus_field}' to Train/Test/Cal splits")

# =============================================================================
# SECTION 2: ALIGN DATA COLUMNS TO REFERENCE PANEL
# =============================================================================
print("\n[STEP 2] Aligning data columns to reference panel (exact names preserved)...")

# Canonical panel names as strings, plus a map from each normalised key back to
# the exact panel spelling.
panel = pd.Index([str(name) for name in TotalSeqD_Heme_Oncology_CAT399906])
panel_keys = MLTraining.norm_feats(panel)
norm_to_panel = dict(zip(panel_keys, panel))
# Two panel names sharing a normalised key would make the lookup ambiguous.
if len(norm_to_panel) != len(panel):
    raise ValueError("Panel contains names that collide after normalization. Adjust MLTraining.norm_feats rules.")

def rename_data_to_panel(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` whose feature columns use the exact panel names.

    Every column except ``cell_barcode``/``Celltype`` is normalised with
    ``MLTraining.norm_feats`` and looked up in the module-level
    ``norm_to_panel`` map. Columns without a panel match keep their name. When
    several columns map to the same panel name, the first (in column order) is
    renamed and the rest are dropped with a warning.
    """
    out = df.copy()
    meta_cols = [c for c in ("cell_barcode", "Celltype") if c in out.columns]
    feature_cols = pd.Index([c for c in out.columns if c not in meta_cols])

    # old column name -> panel name, for columns whose normalised key is known.
    candidates = {}
    for old_name, key in zip(feature_cols, MLTraining.norm_feats(feature_cols)):
        target = norm_to_panel.get(key)
        if target is not None:
            candidates[old_name] = target

    # Resolve collisions: first claimant of a panel name wins.
    claimed = set()
    renames = {}
    duplicates = []
    for old_name, target in candidates.items():
        if target in claimed:
            duplicates.append(old_name)
            continue
        claimed.add(target)
        renames[old_name] = target

    if duplicates:
        print(f"  [WARN] Dropping {len(duplicates)} duplicated-mapped columns (sample: {duplicates[:5]})")
        out = out.drop(columns=duplicates, errors="ignore")

    out = out.rename(columns=renames)
    print(f"  ✓ Matched {len(renames)}/{len(feature_cols)} data columns to panel")
    return out

def panel_intersection(df: pd.DataFrame) -> pd.DataFrame:
    """Restrict *df* to panel features, in module-level ``panel`` order.

    Metadata columns (``cell_barcode`` / ``Celltype``, when present) are kept
    and appended after the features.

    Raises:
        ValueError: if none of the feature columns is a panel name.
    """
    meta_cols = [name for name in ("cell_barcode", "Celltype") if name in df.columns]
    candidate_features = pd.Index([name for name in df.columns if name not in meta_cols])
    shared = panel.intersection(candidate_features, sort=False)
    if len(shared) == 0:
        raise ValueError("Panel/Data intersection is empty after renaming. Check mapping rules.")
    return df.reindex(columns=[*shared, *meta_cols])

# Rename every split's feature columns to exact panel names, then keep only the
# panel features (panel order) followed by the cell_barcode/Celltype columns.
Zhang_data_Train = panel_intersection(rename_data_to_panel(Zhang_data_Train))
Zhang_data_Test  = panel_intersection(rename_data_to_panel(Zhang_data_Test))
Zhang_data_Cal   = panel_intersection(rename_data_to_panel(Zhang_data_Cal))
print("  ✓ Data columns now aligned to panel (panel order preserved)")

# =============================================================================
# SECTION 3: PREPARE FEATURES & LABELS (WITH CAL/TEST ROW FILTERING)
# =============================================================================
# Produces the feature-only frames (*_Sub), the retained class list and its
# index map, the filtered CAL/TEST feature frames and integer labels, and the
# empty containers that Section 4 fills.
print("\n[STEP 3] Extracting features and labels...")

# CAL labels kept aside before the metadata columns are dropped below.
Zhang_data_Cal_lbl = Zhang_data_Cal[["Celltype"]].copy()

drop_cols_train = [c for c in ["cell_barcode", "Celltype"] if c in Zhang_data_Train.columns]
drop_cols_test  = [c for c in ["cell_barcode", "Celltype"] if c in Zhang_data_Test.columns]
drop_cols_cal   = [c for c in ["cell_barcode", "Celltype"] if c in Zhang_data_Cal.columns]

# Feature-only views of each split.
Zhang_data_Train_Sub = Zhang_data_Train.drop(columns=drop_cols_train, errors="ignore")
Zhang_data_Test_Sub  = Zhang_data_Test.drop(columns=drop_cols_test,  errors="ignore")
Zhang_data_Cal_Sub   = Zhang_data_Cal.drop(columns=drop_cols_cal,    errors="ignore")

# All three splits must expose exactly the same feature columns in the same
# order; `cols_train` is later stored in each head bundle as its feature list.
cols_train = list(Zhang_data_Train_Sub.columns)
if list(Zhang_data_Test_Sub.columns) != cols_train or list(Zhang_data_Cal_Sub.columns) != cols_train:
    raise ValueError("Train/Cal/Test feature columns differ after panel intersection!")

MLTraining.check_finite(Zhang_data_Train_Sub, "TRAIN")
MLTraining.check_finite(Zhang_data_Test_Sub,  "TEST")
MLTraining.check_finite(Zhang_data_Cal_Sub,   "CAL")

print(f"  ✓ Using {len(cols_train)} panel-intersected features (exact panel names)")
print(f"    Sample: {cols_train[:5]}...")

# Classes are the sorted, non-null TRAIN labels minus EXCLUDE_CLASSES
# (membership is tested on str(c)).
all_classes = sorted(pd.Series(Zhang_data_Train["Celltype"]).dropna().unique())
class_names = [c for c in all_classes if str(c) not in EXCLUDE_CLASSES]

# NOTE(review): this intersection compares raw label values, while the filter
# above compares str(c); for non-string labels the INFO line can under-report.
excluded_present = sorted(set(all_classes).intersection(EXCLUDE_CLASSES))
if excluded_present:
    print(f"  [INFO] Excluding {len(excluded_present)} classes: {excluded_present}")

# K = number of OvR heads; class_to_idx gives each class its probability column.
K            = len(class_names)
class_to_idx = {c: i for i, c in enumerate(class_names)}
print(f"  ✓ Found {K} classes after exclusions")

# ---- critical: filter CAL/TEST rows to those classes ----
# Rows whose (stringified) label is not a retained class are removed, so every
# remaining CAL/TEST row maps to a column index in [0, K).
keep_set = set(map(str, class_names))

cal_keep_mask  = Zhang_data_Cal_lbl["Celltype"].astype(str).isin(keep_set)
test_keep_mask = Zhang_data_Test["Celltype"].astype(str).isin(keep_set)

n_cal_drop  = int((~cal_keep_mask).sum())
n_test_drop = int((~test_keep_mask).sum())

if n_cal_drop > 0:
    dropped = sorted(Zhang_data_Cal_lbl.loc[~cal_keep_mask, "Celltype"].astype(str).unique().tolist())
    print(f"  [INFO] Dropping {n_cal_drop} CAL rows with excluded/unknown labels: {dropped}")

if n_test_drop > 0:
    dropped = sorted(Zhang_data_Test.loc[~test_keep_mask, "Celltype"].astype(str).unique().tolist())
    print(f"  [INFO] Dropping {n_test_drop} TEST rows with excluded/unknown labels: {dropped}")

# filtered label frames
Zhang_data_Cal_lbl_f  = Zhang_data_Cal_lbl.loc[cal_keep_mask].copy()
Zhang_data_Test_lbl_f = Zhang_data_Test.loc[test_keep_mask, ["Celltype"]].copy()

# filtered feature frames (must align by index)
# NOTE(review): `.loc[index]` repeats rows if barcodes are duplicated in the
# index — assumes unique cell barcodes; confirm upstream.
X_cal_all_df = Zhang_data_Cal_Sub.loc[Zhang_data_Cal_lbl_f.index].copy()
X_te_all_df  = Zhang_data_Test_Sub.loc[Zhang_data_Test_lbl_f.index].copy()
test_index   = X_te_all_df.index

# map filtered labels to integer class indices (fail loudly on any leftover)
s_cal = Zhang_data_Cal_lbl_f["Celltype"].map(class_to_idx)
if s_cal.isna().any():
    missing = Zhang_data_Cal_lbl_f.loc[s_cal.isna(), "Celltype"].unique()
    raise ValueError(f"Still-unmapped labels in CAL after filtering: {missing}")
y_cal_multiclass = s_cal.to_numpy(dtype=np.int64)

s_te = Zhang_data_Test_lbl_f["Celltype"].map(class_to_idx)
if s_te.isna().any():
    missing = Zhang_data_Test_lbl_f.loc[s_te.isna(), "Celltype"].unique()
    raise ValueError(f"Still-unmapped labels in TEST after filtering: {missing}")
y_test_multiclass = s_te.to_numpy(dtype=np.int64)

# probability matrices sized to filtered CAL/TEST (rows = cells, cols = classes);
# column k is filled by the k-th OvR head in Section 4.
P_cal_raw   = np.zeros((X_cal_all_df.shape[0], K), dtype=float)
P_cal_platt = np.zeros((X_cal_all_df.shape[0], K), dtype=float)

P_te_raw    = np.zeros((X_te_all_df.shape[0],  K), dtype=float)
P_te_platt  = np.zeros((X_te_all_df.shape[0],  K), dtype=float)

# label -> per-class head bundle, packaged in Section 5b
heads_mem = {}

# Row accumulators for the importance and calibration-metric exports.
xgb_shap_rows      = []
lr_contrib_rows    = []
platt_metrics_rows = []

# =============================================================================
# SECTION 4: TRAIN OvR BINARY HEADS (+ Platt on CAL)
# =============================================================================
# For every retained class k:
#   - train a stacked binary head (NB/XGB/KNN/MLP -> class-balanced logistic
#     regression) on that class's TRAIN barcode file;
#   - Platt-calibrate it on CAL when CAL contains both outcomes;
#   - fill column k of the RAW/PLATT probability matrices;
#   - export plots, importances and the head bundle.
# NOTE: a class skipped at 4.1 keeps all-zero columns in the P_* matrices and
# gets no entry in `heads_mem`, but it remains in `class_names`.
print(f"\n[STEP 4] Training {K} binary OvR classifiers...\n")

TOP_N = 10
base_order = ["NB", "XGB", "KNN", "MLP"]

for celltype in class_names:
    k = class_to_idx[celltype]
    cls_safe = MLTraining.safe_name(celltype)
    print(f"▸ Processing {cls_safe} (class {k+1}/{K})")

    # 4.1 Load TRAIN barcodes for this class ("Positive" / "Negative" columns)
    train_barcodes_df = pd.read_csv(
        f"{train_barcodes_path}/Zhang/Consensus_annotation_{name_target_class.lower()}_final/Barcodes_training_class_{cls_safe}.csv",
        index_col=0
    )
    train_positive_barcodes = train_barcodes_df["Positive"].dropna().values
    train_negative_barcodes = train_barcodes_df["Negative"].dropna().values
    all_train_barcodes = np.concatenate([train_positive_barcodes, train_negative_barcodes])

    train_mask = Zhang_data_Train_Sub.index.isin(all_train_barcodes)
    X_tr_df = Zhang_data_Train_Sub.loc[train_mask]
    found_train_barcodes = X_tr_df.index.values
    y_tr = np.isin(found_train_barcodes, train_positive_barcodes).astype(int)

    if X_tr_df.empty or np.unique(y_tr).size < 2:
        print(f"  [SKIP] Empty or single-class train (pos={y_tr.sum()}, neg={len(y_tr)-y_tr.sum()})\n")
        continue

    # 4.1b TRAIN embedding (pos vs rest) + legend
    try:
        MLTraining.save_class_train_umap_pngs(
            celltype=str(celltype),
            cls_safe=cls_safe,
            barcodes=found_train_barcodes,
            y_bin=y_tr,
            custom_palette=custom_palette,
            out_dir_dev=fig_percls if EXPORT_DEV else None,
            out_dir_rel=release_single if EXPORT_RELEASE else None,
            adata_train=Zhang_dataset_Train,
            train_df=Zhang_data_Train,
            embedding_source=EMBEDDING_SOURCE,
            obsm_key=EMBEDDING_OBSM_KEY,
            obs_x=EMBEDDING_OBS_X,
            obs_y=EMBEDDING_OBS_Y,
            df_x=EMBEDDING_DF_X,
            df_y=EMBEDDING_DF_Y,
            neg_color="#A3A3A3",
            outline=(5, 0.05),
            debug=False,
        )
    except Exception as e:
        warnings.warn(f"UMAP train plot failed for '{celltype}': {e}")

    # 4.2 Load TEST barcodes for class-specific metrics (optional)
    test_barcodes_df = pd.read_csv(
        f"{test_barcodes_path}/Zhang/Consensus_annotation_{name_target_class.lower()}_final/Barcodes_testing_class_{cls_safe}.csv",
        index_col=0
    )
    test_positive_barcodes = test_barcodes_df["Positive"].dropna().values
    test_negative_barcodes = test_barcodes_df["Negative"].dropna().values
    all_test_barcodes = np.concatenate([test_positive_barcodes, test_negative_barcodes])

    test_mask = Zhang_data_Test_Sub.index.isin(all_test_barcodes)
    X_te_df = Zhang_data_Test_Sub.loc[test_mask]
    found_test_barcodes = X_te_df.index.values
    y_te = np.isin(found_test_barcodes, test_positive_barcodes).astype(int)

    # Full TEST (filtered) for head probabilities / calibration plot eval
    X_te_all_local = X_te_all_df
    y_te_all = (Zhang_data_Test.loc[X_te_all_df.index, "Celltype"].astype(str).values == str(celltype)).astype(int)

    # CAL split (filtered) for Platt fitting
    X_cal_df  = X_cal_all_df
    y_cal_bin = (Zhang_data_Cal.loc[X_cal_all_df.index, "Celltype"].astype(str).values == str(celltype)).astype(int)

    # 4.3 Fit scaler on TRAIN; transform all splits
    scaler = StandardScaler(with_mean=True, with_std=True).fit(X_tr_df.values)

    def _sc(df: pd.DataFrame) -> pd.DataFrame:
        return pd.DataFrame(scaler.transform(df.values), index=df.index, columns=cols_train)

    X_tr_sc_df      = _sc(X_tr_df)
    # StandardScaler.transform rejects 0-row input; the class-specific TEST
    # subset only feeds the optional metrics in 4.9, so tolerate it being empty.
    X_te_sc_df      = _sc(X_te_df) if not X_te_df.empty else None
    X_te_all_sc_df  = _sc(X_te_all_local)
    X_cal_sc_df     = _sc(X_cal_df)

    # 4.4 Train base learners
    NB_model  = MLTraining.train_NB (X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)
    XGB_model = MLTraining.train_XGB(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)
    KNN_model = MLTraining.train_KNN(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)
    MLP_model = MLTraining.train_MLP(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)

    # 4.5 Stacking RAW head
    stacker_raw = StackingClassifier(
        estimators=[("NB", NB_model), ("XGB", XGB_model), ("KNN", KNN_model), ("MLP", MLP_model)],
        final_estimator=LogisticRegression(max_iter=2000, class_weight="balanced", random_state=42),
        stack_method="predict_proba",
        cv=kf,
        n_jobs=-1,
    ).fit(X_tr_sc_df, y_tr)

    # 4.6 Platt calibration (fit on CAL only; needs both outcomes present)
    pos_cal   = int(y_cal_bin.sum())
    n_cal_bin = int(len(y_cal_bin))
    has_both  = (0 < pos_cal < n_cal_bin)

    stacker_platt = None
    if has_both:
        stacker_platt = MLTraining.calibrate_prefit(stacker_raw, X_cal_sc_df, y_cal_bin, method="sigmoid")
    else:
        print("    [WARN] Skipped Platt calibration (single-class CAL)")

    # 4.6b Positive-class probabilities on the full (filtered) CAL and TEST
    # splits. Computed once here and reused by the calibration plot (4.7) and
    # the OvR matrices (4.10) instead of re-running the stack's predict_proba.
    p_cal_raw  = stacker_raw.predict_proba(X_cal_sc_df)[:, 1]
    p_test_raw = stacker_raw.predict_proba(X_te_all_sc_df)[:, 1]
    if stacker_platt is not None:
        p_cal_platt  = stacker_platt.predict_proba(X_cal_sc_df)[:, 1]
        p_test_platt = stacker_platt.predict_proba(X_te_all_sc_df)[:, 1]
    else:
        p_cal_platt  = p_cal_raw  # no calibrator: PLATT column mirrors RAW
        p_test_platt = None

    # 4.7 Platt evaluation curve on TEST (Ideal -> RAW -> Platt) + metrics row
    try:
        dev_platt = (fig_percls / f"{cls_safe}_Platt_calibration_evaluation_on_test.png") if EXPORT_DEV else None
        rel_platt = (release_single / f"{cls_safe}_Platt_calibration_evaluation_on_test.png") if EXPORT_RELEASE else None

        ll_raw, br_raw, ll_pl, br_pl, pl_avail = MLTraining.plot_platt_calibration_on_test(
            y_true_bin=y_te_all.astype(int),
            p_raw=p_test_raw,
            p_platt=p_test_platt,
            title=f"{name_target_class} – {celltype}: Platt calibration evaluation on TEST",
            out_png_dev=dev_platt,
            out_png_rel=rel_platt,
            n_bins=15,
        )

        platt_metrics_rows.append({
            "depth": name_target_class,
            "class_name": str(celltype),
            "n_test_samples": int(len(y_te_all)),
            "n_test_positive": int(y_te_all.sum()),
            "logloss_raw": ll_raw,
            "brier_raw": br_raw,
            "logloss_platt": ll_pl,
            "brier_platt": br_pl,
            "platt_available": bool(pl_avail),
        })

    except Exception as e:
        warnings.warn(f"Platt calibration plot failed for class '{celltype}': {e}")

    # 4.8 Save per-class head bundle + keep in-memory for package
    head_bundle = {
        "atlas": "Zhang",
        "depth": name_target_class,
        "label": str(celltype),
        "panel_name": "TotalSeqD_Heme_Oncology_CAT399906",
        "columns": cols_train,
        "scaler": scaler,
        "model_raw": stacker_raw,
        "model_platt": stacker_platt,
    }
    heads_mem[str(celltype)] = head_bundle

    if EXPORT_DEV:
        joblib.dump(head_bundle, heads_dir / f"{cls_safe}.joblib")

    # 4.9 Optional per-head metrics logging (class-specific TEST subset).
    # Best-effort, but failures are reported instead of silently dropped.
    if X_te_sc_df is None:
        print("    [WARN] No class-specific TEST barcodes found; skipped per-head metrics")
    else:
        try:
            model_for_eval = stacker_platt if stacker_platt is not None else stacker_raw
            m = MLTraining.evaluate_classifier(model_for_eval, X_te_sc_df, y_te, plot_cm=False)
            m.update(celltype=str(celltype), used_platt=bool(stacker_platt is not None))
            metrics_log.append(m)
        except Exception as e:
            warnings.warn(f"Per-head metrics logging failed for class '{celltype}': {e}")

    # 4.10 OvR probability matrices (RAW + PLATT) for multiclass downstream
    P_cal_raw[:, k]   = p_cal_raw
    P_te_raw[:,  k]   = p_test_raw
    P_cal_platt[:, k] = p_cal_platt
    P_te_platt[:,  k] = p_test_platt if p_test_platt is not None else p_test_raw

    # 4.11 SHAP: mean_abs + corr on TEST; beeswarm TRAIN only
    if HAS_SHAP:
        shap_fig = None
        try:
            # Scratch figure kept from the original flow; it is now closed in
            # `finally` so pyplot no longer accumulates one open figure per class.
            shap_fig = plt.figure(figsize=(6, 6))
            shap_sum_test = MLTraining.xgb_shap_mean_abs_and_corr(XGB_model, X_te_all_sc_df, class_index=1)
            shap_sum_test["depth"] = name_target_class
            shap_sum_test["class_name"] = str(celltype)
            shap_sum_test["dataset"] = "TEST"
            xgb_shap_rows.extend(shap_sum_test.to_dict(orient="records"))

            # Beeswarm on TRAIN only; SHAP values are computed once and reused
            # for both the DEV and RELEASE images.
            beeswarm_paths = []
            if EXPORT_DEV:
                beeswarm_paths.append(fig_percls / f"{cls_safe}_SHAP_beeswarm_TRAIN.png")
            if EXPORT_RELEASE:
                beeswarm_paths.append(release_single / f"{cls_safe}_SHAP_beeswarm_TRAIN.png")
            if beeswarm_paths:
                sv, _ = MLTraining.xgb_shap_values(XGB_model, X_tr_sc_df, class_index=1)
                for outp in beeswarm_paths:
                    MLTraining.plot_xgb_shap_beeswarm(
                        sv, X_tr_sc_df,
                        title=f"{name_target_class} – {celltype}: SHAP importance",
                        max_display=TOP_N,
                        out_path=outp,
                        figsize=(6, 7),
                    )

        except Exception as e:
            warnings.warn(f"SHAP failed for class '{celltype}': {e}")
        finally:
            if shap_fig is not None:
                plt.close(shap_fig)

    # 4.12 LR meta-learner contributions (unchanged)
    try:
        contrib = _lr_baselearner_contributions(stacker_raw, X_te_all_sc_df, base_order=base_order)
        row = {
            "depth": name_target_class,
            "class_name": str(celltype),
            "dataset": "TEST",
            "n_meta_features": contrib["n_meta_features"],
            "per_estimator_meta_cols": contrib["per_estimator_meta_cols"],
        }
        for b in base_order:
            row[f"{b}_mean_abs_contribution"] = contrib["per_base"].get(b, {}).get("mean_abs_contribution", 0.0)
            row[f"{b}_coef_l1"]               = contrib["per_base"].get(b, {}).get("coef_l1", 0.0)
            row[f"{b}_n_meta_cols"]           = contrib["per_base"].get(b, {}).get("n_cols", 0)
        lr_contrib_rows.append(row)
    except Exception as e:
        warnings.warn(f"LR contribution extraction failed for class '{celltype}': {e}")

    print("")

# =============================================================================
# EXPORT: Per-class LogLoss & Brier (pre vs post Platt) on TEST
# =============================================================================
# One row per trained head (collected in step 4.7); each destination is only
# resolved when its export switch is on.
print("\n[EXPORT] Per-class calibration metrics (RAW vs Platt on TEST)...")

_ = MLTraining.export_platt_metrics_csv(
    platt_metrics_rows,
    filename="Single_classes_metrics_pre_and_post_platt_calibration.csv",
    out_rel=(release_metrics if EXPORT_RELEASE else None),
    out_dev=(metrics_dir if EXPORT_DEV else None),
)

# =============================================================================
# SECTION 5: MULTICLASS TEMPERATURE SCALING (fit on CAL using PLATT matrix)
# =============================================================================
print("\n[STEP 5] Multiclass Temperature Scaling on CAL (using Platt OvR probabilities)...")

def _check_probs(P: np.ndarray, name: str):
    """Validate that *P* holds finite values inside [0, 1].

    Args:
        P: probability array of any shape.
        name: label used in the error message.

    Raises:
        ValueError: if any entry is NaN or +/-Inf, or lies outside [0, 1].
    """
    if not np.isfinite(P).all():
        raise ValueError(f"{name} contains NaN/Inf")
    out_of_range = (P < 0) | (P > 1)
    if out_of_range.any():
        raise ValueError(f"{name} contains values outside [0,1]")

# Both OvR PLATT matrices must be finite and within [0, 1] before fitting.
_check_probs(P_cal_platt, "P_cal_platt")
_check_probs(P_te_platt,  "P_te_platt")

# One multiclass temperature, fit on CAL (PLATT OvR matrix + integer labels),
# then applied to the TEST PLATT matrix.
ts_cal = TemperatureScaling()
ts_cal.fit(P_cal_platt, y_cal_multiclass)
P_te_cal = np.asarray(ts_cal.transform(P_te_platt))

# Remember what TemperatureScaling actually returned: the fallback branch
# below overwrites P_te_cal, and its warning must report the original shape.
ts_returned_shape = P_te_cal.shape

if P_te_cal.ndim == 1:
    P_te_cal = P_te_cal.reshape(-1, 1)

if P_te_cal.shape[1] == 1 and K == 2:
    # Single probability column for a 2-class problem -> expand to [p0, p1].
    P_te_cal = np.hstack([1.0 - P_te_cal, P_te_cal])
elif P_te_cal.shape[1] != K:
    # Unexpected width: fall back to row-normalised OvR PLATT probabilities
    # (all-zero rows are left at zero rather than divided by zero).
    # NOTE(review): the saved `temp_scaler` is still the fitted TS object, so
    # downstream inference will not reproduce this fallback automatically.
    row_sums = P_te_platt.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0.0] = 1.0
    P_te_cal = P_te_platt / row_sums
    print(f"  [WARN] TemperatureScaling returned shape {ts_returned_shape}; fell back to sum-normalized OvR probs")

if EXPORT_DEV:
    joblib.dump(ts_cal, models_root / "temp_scaler.joblib")
    pd.Series(class_names, name="class_name").to_csv(models_root / "class_names.csv", index=False)

# =============================================================================
# SECTION 5b: SAVE DEPLOYABLE PACKAGE(S)
# =============================================================================
print("\n[STEP 5b] Saving deployable package(s)...")

# Inference bundle: per-class heads keyed by class label (scaler + RAW/PLATT
# stackers), the class order used for probability columns, and the fitted
# temperature scaler.
# NOTE(review): classes skipped in Section 4 stay in `class_names` but have no
# entry in `heads` — confirm consumers handle the missing key.
package = dict(
    atlas="Zhang",
    depth=name_target_class,
    panel_name="TotalSeqD_Heme_Oncology_CAT399906",
    class_names=class_names,
    heads=heads_mem,
    temp_scaler=ts_cal,
)

if EXPORT_DEV:
    joblib.dump(package, models_root / "package.joblib")
if EXPORT_RELEASE:
    joblib.dump(package, release_models / "Multiclass_models.joblib")

# =============================================================================
# SECTION 5c: EXPORT IMPORTANCES (Top10 per class)
# =============================================================================
print("\n[STEP 5c] Exporting importances (Top 10 per class; SHAP mean_abs + corr + LR)...")

if not xgb_shap_rows:
    print("  [INFO] No SHAP rows collected (or SHAP not installed).")
else:
    # Keep the TOP_N features with the largest mean |SHAP| per (depth, class).
    shap_df = pd.DataFrame(xgb_shap_rows).sort_values(
        ["depth", "class_name", "mean_abs_shap"], ascending=[True, True, False]
    )
    shap_df = shap_df.groupby(["depth", "class_name"], as_index=False).head(TOP_N)

    # 1-based rank by descending mean |SHAP|; ties broken by row order.
    shap_ranks = shap_df.groupby(["depth", "class_name"])["mean_abs_shap"].rank(
        ascending=False, method="first"
    )
    shap_df["rank_within_class"] = shap_ranks.astype(int)

    shap_df = shap_df[
        ["depth", "class_name", "dataset", "feature", "mean_abs_shap", "corr_feature_value_vs_shap", "rank_within_class"]
    ]
    if EXPORT_DEV:
        shap_df.to_csv(dev_importances / "SHAP_XGB_Feature_importances.csv", index=False)
    if EXPORT_RELEASE:
        shap_df.to_csv(release_imps / "SHAP_XGB_Feature_importances.csv", index=False)

if not lr_contrib_rows:
    print("  [INFO] No LR contribution rows collected.")
else:
    # One row per class: per-base-learner contribution summary of the LR meta-learner.
    lr_df = pd.DataFrame(lr_contrib_rows)
    if EXPORT_DEV:
        lr_df.to_csv(dev_importances / "LR_MetaLearner_BaseLearner_contributions.csv", index=False)
    if EXPORT_RELEASE:
        lr_df.to_csv(release_imps / "LR_MetaLearner_BaseLearner_contributions.csv", index=False)

# =============================================================================
# SECTION 6: SAVE PROBABILITIES
# =============================================================================
print("\n[STEP 6] Saving probability outputs...")

if EXPORT_DEV:
    # Side-by-side RAW / PLATT / temperature-scaled TEST probabilities (one
    # prefixed column per class), the true label, and argmax predictions as
    # both class index and class name.
    probs_raw_df, probs_platt_df, probs_cal_df = (
        pd.DataFrame(matrix, index=test_index, columns=[f"{prefix}_{c}" for c in class_names])
        for prefix, matrix in (("raw", P_te_raw), ("platt", P_te_platt), ("cal", P_te_cal))
    )

    probs_dev = pd.concat([probs_raw_df, probs_platt_df, probs_cal_df], axis=1)
    probs_dev["true_label"] = Zhang_data_Test.loc[test_index, "Celltype"].values
    probs_dev["pred_raw"] = P_te_raw.argmax(axis=1)
    probs_dev["pred_cal"] = P_te_cal.argmax(axis=1)
    probs_dev["pred_raw_name"] = [class_names[i] for i in probs_dev["pred_raw"].to_numpy()]
    probs_dev["pred_cal_name"] = [class_names[i] for i in probs_dev["pred_cal"].to_numpy()]
    probs_dev.to_csv(probs_dir / "probabilities_before_after_TEST.csv", index=True)

if EXPORT_RELEASE:
    # Release file: temperature-scaled probabilities only, plus the predicted
    # class and its probability.
    probs_cal_df = pd.DataFrame(P_te_cal, index=test_index, columns=[f"cal_{c}" for c in class_names])
    probs_release = probs_cal_df.assign(
        true_label=Zhang_data_Test.loc[test_index, "Celltype"].values,
        pred_cal=P_te_cal.argmax(axis=1),
        pred_cal_name=lambda frame: [class_names[i] for i in frame["pred_cal"].values],
        max_cal_prob=probs_cal_df.max(axis=1).values,
    )
    probs_release.to_csv(release_probs / "Multiclass_models_probabilities_on_test.csv", index=True)

# =============================================================================
# SECTION 7: MULTICLASS EVALUATION (TEST) — using CAL probabilities
# =============================================================================
print("\n[STEP 7] Multiclass evaluation (TEST; using CAL probs)...\n")

# Predicted class index = column of the highest temperature-scaled probability.
y_pred_cal = P_te_cal.argmax(axis=1)

# Pin the report to all K trained classes. Without `labels`, sklearn infers the
# label set from y_true ∪ y_pred and raises ValueError ("Number of classes ...
# does not match size of target_names") whenever a trained class is absent from
# TEST and never predicted. zero_division=0 keeps the default numeric result
# (0.0) for such empty rows without emitting the warning.
report_labels = list(range(K))
report_txt = classification_report(
    y_test_multiclass, y_pred_cal,
    labels=report_labels, target_names=class_names, digits=3, zero_division=0,
)
print("Multiclass Classification Report (TEST):")
print(report_txt)

cm_mc = confusion_matrix(y_test_multiclass, y_pred_cal, labels=report_labels)
print("\nConfusion Matrix (rows=true, cols=pred):")
print(cm_mc)

report_df = pd.DataFrame(
    classification_report(
        y_test_multiclass, y_pred_cal,
        labels=report_labels, target_names=class_names, output_dict=True, zero_division=0,
    )
).T

cm_mc_df = pd.DataFrame(cm_mc, index=pd.Index(class_names, name="true"), columns=pd.Index(class_names, name="pred"))

if EXPORT_DEV:
    report_df.to_csv(metrics_dir / "multiclass_classification_report_TEST.csv")
    cm_mc_df.to_csv(metrics_dir / "multiclass_confusion_matrix_TEST.csv")

if EXPORT_RELEASE:
    report_df.to_csv(release_metrics / "Multiclass_models_metrics_on_test.csv")
    cm_mc_df.to_csv(release_metrics / "Multiclass_models_confusion_matrix_on_test.csv")

# =============================================================================
# SECTION 8: FIGURES (MULTICLASS CM + PER-CLASS CONF & ROC)
# =============================================================================
print("\n[STEP 8] Saving plots...")

def _save_multiclass_cm_png(out_path: Path):
    """Render the K x K TEST confusion matrix and save it to *out_path* (300 dpi).

    Reads the module-level ``cm_mc``, ``class_names`` and ``name_target_class``.
    """
    # ConfusionMatrixDisplay.plot() creates a brand-new figure when no `ax` is
    # passed. Plotting onto an explicit axes makes figsize apply and ensures
    # the figure that was actually drawn is the one that gets closed.
    fig, ax = plt.subplots(figsize=(7, 6))
    try:
        disp = ConfusionMatrixDisplay(confusion_matrix=cm_mc, display_labels=class_names)
        disp.plot(ax=ax, values_format="d", cmap="Blues", colorbar=False)
        ax.set_title(f"{name_target_class} – Multiclass Confusion Matrix (on TEST)")
        fig.tight_layout()
        fig.savefig(out_path, dpi=300)
    finally:
        plt.close(fig)

if EXPORT_DEV:
    _save_multiclass_cm_png(fig_root / "multiclass_confusion_matrix_TEST.png")
if EXPORT_RELEASE:
    _save_multiclass_cm_png(release_figs / "Multiclass_models_confusion_matrix_on_test.png")

# Per-class metric rows (filled below) and the RAW argmax predictions.
per_class_rows = []
y_pred_raw = P_te_raw.argmax(axis=1)

def _metrics_from_cm(cm2x2):
    """Summarise a 2x2 binary confusion matrix laid out as [[TN, FP], [FN, TP]].

    Returns:
        dict with integer TP/FP/TN/FN counts, the positive-class ``support``
        (TP + FN), and ``precision`` / ``recall`` / ``f1``. Each ratio is 0.0
        when its denominator is zero.
    """
    tn, fp, fn, tp = cm2x2.ravel()
    predicted_pos = tp + fp
    actual_pos = tp + fn
    precision = tp / predicted_pos if predicted_pos > 0 else 0.0
    recall = tp / actual_pos if actual_pos > 0 else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom > 0 else 0.0
    return {
        "TP": int(tp),
        "FP": int(fp),
        "TN": int(tn),
        "FN": int(fn),
        "support": int(actual_pos),
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }

def _save_cm_fig(cm2x2, cls_label, title, out_dev: Path | None, out_rel: Path | None):
    """Render a 2x2 one-vs-rest confusion matrix and save it to each given path.

    Args:
        cm2x2: 2x2 confusion matrix ordered [Other, cls_label] on both axes
            (as produced by ``confusion_matrix(..., labels=[0, 1])``).
        cls_label: display name of the positive class.
        title: axes title.
        out_dev: DEV PNG path, or None to skip.
        out_rel: RELEASE PNG path, or None to skip.
    """
    # Plot onto an explicit axes: without `ax`, ConfusionMatrixDisplay.plot()
    # opens a second figure, leaving the sized one empty and the drawn one open.
    fig, ax = plt.subplots(figsize=(5.5, 5.0))
    try:
        ConfusionMatrixDisplay(confusion_matrix=cm2x2, display_labels=["Other", cls_label]).plot(
            ax=ax, values_format="d", cmap="Blues", colorbar=False
        )
        ax.set_title(title)
        fig.tight_layout()
        for out_path in (out_dev, out_rel):
            if out_path is not None:
                fig.savefig(out_path, dpi=300)
    finally:
        plt.close(fig)

def _save_roc(y_true, y_score, title, out_dev: Path | None, out_rel: Path | None):
    """Plot the ROC curve of *y_score* against binary *y_true* and return its AUC.

    The figure (dashed chance diagonal + ROC curve, AUC in the title) is saved
    at 300 dpi to each of *out_dev* / *out_rel* that is not None, then closed.
    """
    fpr, tpr, _thresholds = roc_curve(y_true, y_score)
    roc_auc = auc(fpr, tpr)

    fig = plt.figure(figsize=(6.0, 5.5))
    ax = fig.add_subplot()
    ax.plot([0, 1], [0, 1], linestyle="--", linewidth=1, color="gray")
    ax.plot(fpr, tpr)
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title(f"{title} AUC={roc_auc:.3f}")
    fig.tight_layout()
    for destination in (out_dev, out_rel):
        if destination is not None:
            fig.savefig(destination, dpi=300)
    plt.close(fig)
    return roc_auc

# Per-class one-vs-rest views of the two multiclass predictors:
#   RAW = argmax over the uncalibrated OvR head probabilities (P_te_raw),
#   CAL = argmax over the temperature-scaled probabilities (P_te_cal).
# For each class: a 2x2 confusion matrix and a ROC curve per variant are saved,
# and one metrics row per variant is appended to `per_class_rows`.
for k, cls in enumerate(class_names):
    cls_safe = MLTraining.safe_name(cls)
    # Binary truth: 1 where the TEST cell's class index is k.
    y_true_bin = (y_test_multiclass == k).astype(int)

    # Column k scores drive the ROC curves.
    score_raw = P_te_raw[:, k]
    score_cal = P_te_cal[:, k]

    # Binary predictions: 1 where the multiclass argmax picked class k.
    y_pred_raw_bin = (y_pred_raw == k).astype(int)
    y_pred_cal_bin = (y_pred_cal == k).astype(int)

    # labels=[0, 1] forces a 2x2 matrix even if one outcome is absent.
    cm_raw = confusion_matrix(y_true_bin, y_pred_raw_bin, labels=[0, 1])
    cm_cal = confusion_matrix(y_true_bin, y_pred_cal_bin, labels=[0, 1])

    dev_out = (fig_percls / f"{cls_safe}_RAW_confusion_matrix_on_test.png") if EXPORT_DEV else None
    rel_out = (release_single / f"{cls_safe}_RAW_confusion_matrix_on_test.png") if EXPORT_RELEASE else None
    _save_cm_fig(cm_raw, cls, f"{name_target_class} – {cls}: Confusion Matrix (RAW; pre-Platt & Temp)", dev_out, rel_out)

    dev_out = (fig_percls / f"{cls_safe}_CAL_confusion_matrix_on_test.png") if EXPORT_DEV else None
    rel_out = (release_single / f"{cls_safe}_CAL_confusion_matrix_on_test.png") if EXPORT_RELEASE else None
    _save_cm_fig(cm_cal, cls, f"{name_target_class} – {cls}: Confusion Matrix (CAL; Platt & Temp)", dev_out, rel_out)

    dev_out = (fig_percls / f"{cls_safe}_RAW_ROC_on_test.png") if EXPORT_DEV else None
    rel_out = (release_single / f"{cls_safe}_RAW_ROC_on_test.png") if EXPORT_RELEASE else None
    auc_raw = _save_roc(y_true_bin, score_raw, f"{name_target_class} – {cls}: ROC (RAW; pre-Platt & Temp)", dev_out, rel_out)

    dev_out = (fig_percls / f"{cls_safe}_CAL_ROC_on_test.png") if EXPORT_DEV else None
    rel_out = (release_single / f"{cls_safe}_CAL_ROC_on_test.png") if EXPORT_RELEASE else None
    auc_cal = _save_roc(y_true_bin, score_cal, f"{name_target_class} – {cls}: ROC (CAL; Platt & Temp)", dev_out, rel_out)

    # Counts/precision/recall/F1 from the 2x2 matrix, tagged with variant, class and AUC.
    m_raw = _metrics_from_cm(cm_raw); m_raw.update(model="RAW", class_name=cls, auc=auc_raw); per_class_rows.append(m_raw)
    m_cal = _metrics_from_cm(cm_cal); m_cal.update(model="CAL", class_name=cls, auc=auc_cal); per_class_rows.append(m_cal)

# =============================================================================
# SECTION 9: SAVE METRICS TABLES
# =============================================================================
print("\n[STEP 9] Saving metrics tables...")

# One row per (class, RAW|CAL) pair, in a fixed column order.
per_class_df = (
    pd.DataFrame(per_class_rows)
    .loc[:, ["class_name", "model", "TP", "FP", "TN", "FN", "support", "precision", "recall", "f1", "auc"]]
    .sort_values(["class_name", "model"])
)

if EXPORT_DEV:
    per_class_df.to_csv(metrics_dir / "per_class_argmax_metrics_TEST_included.csv", index=False)

if EXPORT_RELEASE:
    out_single = release_metrics / "Single_classes_metrics_and_confusion_matrix_on_test.csv"
    per_class_df.to_csv(out_single, index=False)

# Append this run's per-head TEST metrics (collected in step 4.9) to the DEV log.
if EXPORT_DEV:
    metrics_df = pd.DataFrame.from_records(metrics_log)
    MLTraining.append_metrics_csv(metrics_df, csv_path=dev_root / "stacker_metrics.csv")

print("\n✅ DETAILED PIPELINE COMPLETE. Exports saved according to EXPORT_DEV / EXPORT_RELEASE.\n")


## Triana Models

In [None]:
# Create the Triana output tree. os.makedirs builds any missing parents, so
# creating the two leaf folders also creates Triana/, Triana/Dev/ and
# Triana/Release/.
models_output = data_path + "/Triana"
for _triana_leaf in ("Dev/Models", "Release/Models"):
    os.makedirs(models_output + "/" + _triana_leaf, exist_ok=True)

### ML Training

In [None]:
# Registry for the Triana models trained in the cells below (starts empty).
Triana_Models = dict()

#### Broad annotation

In [None]:
# -*- coding: utf-8 -*-
# =============================================================================
# MODEL TRAINING PIPELINE (LEAN MAIN SCRIPT)
#   - RAW vs PLATT vs TEMP-SCALED
#   - DEV/RELEASE exports
#   - Importances: XGB SHAP mean_abs + corr (Top10) + LR meta-learner contributions
#   - Platt calibration plots (Ideal -> RAW -> Platt on top) with TEST LogLoss/Brier in legend
#   - Per-class pre/post Platt metrics exported to CSV
#   - Per-class TRAIN UMAP (pos vs rest) + legend PNG
#
# NOTE:
#   Many helper functions are now provided by MLTraining.py.
#   This script should primarily orchestrate: data prep -> model training loop -> exports.
# =============================================================================

from pathlib import Path
import joblib
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import StackingClassifier
from sklearn.calibration import CalibratedClassifierCV
from sklearn.metrics import (
    classification_report, confusion_matrix, ConfusionMatrixDisplay,
    roc_curve, auc,
)

import MLTraining  # uses MLTraining.py helpers

# -----------------------------------------------------------------------------
# Palettes: cell-type label -> hex colour, one mapping per annotation depth
# ("Broad", "Simplified", "Detailed"). Key order is kept stable on purpose,
# since dict iteration order is what callers see.
# -----------------------------------------------------------------------------

PALETTE_BROAD = {
    "Immature": "#0079ea",
    "Mature":   "#AF3434",
}

PALETTE_SIMPLIFIED = {
    "HSPC":      "#0079ea",
    "Erythroid": "#c11212",
    "pDC":       "#62E6B8",
    "Monocyte":  "#D27CE3",
    "Myeloid":   "#8D43CD",
    "CD4_T":     "#C99546",
    "CD8_T":     "#6B3317",
    "B":         "#68D827",
    "cDC":       "#16D2E3",
    "Other_T":   "#EDB416",
    "NK":        "#FBEF0D",
}

# NOTE: "Monocyte" and "CD14 Mono" intentionally share the same colour.
PALETTE_DETAILED = {
    "HSC_MPP":            "#0079ea",
    "LMPP":               "#17BECF",
    "GMP":                "#C5E4FF",
    "Myeloid progenitor": "#AEC7E8",
    "Monocyte":           "#D27CE3",
    "CD14 Mono":          "#D27CE3",
    "CD16 Mono":          "#8D43CD",
    "Erythroblast":       "#F30A1A",
    "ErP":                "#D1235A",
    "MEP":                "#E364B0",
    "CD4 T Naive":        "#C99546",
    "CD4 T Memory":       "#C1AF93",
    "CD8 T Naive":        "#4D382E",
    "CD8 T Memory":       "#6B3317",
    "Other_T":            "#EDB416",
    "Treg":               "#6E6C37",
    "B Naive":            "#1C511D",
    "B Memory":           "#68D827",
    "Pro-B":              "#66BB6A",
    "Pre-B":              "#2DBD67",
    "Immature B":         "#91FF7B",
    "Plasma":             "#9DC012",
    "cDC1":               "#76A7CB",
    "cDC2":               "#16D2E3",
    "pDC":                "#69FFCB",
    "NK CD56 bright":     "#F3AC1F",
    "NK CD56 dim":        "#FBEF0D",
}

# -----------------------------------------------------------------------------
# OPTIONAL: SHAP dependency
# -----------------------------------------------------------------------------
# HAS_SHAP gates the SHAP importance and beeswarm exports in the training loop.
# Any failure while importing shap (not only ImportError) disables them.
HAS_SHAP = False
try:
    import shap  # noqa: F401
except Exception:
    pass
else:
    HAS_SHAP = True


# -----------------------------------------------------------------------------
# EXPORT SWITCHES
# -----------------------------------------------------------------------------
# The exports in this cell (joblib bundles, CSVs, figures) are gated on these
# flags, and the Dev/Release directory trees are only created when enabled.
EXPORT_RELEASE = True    # set False to disable Release outputs
EXPORT_DEV     = False   # set True to enable Dev outputs


# -----------------------------------------------------------------------------
# CONFIG
# -----------------------------------------------------------------------------
# Annotation depth to train. It selects the consensus label column
# (Consensus_annotation_<depth>_final), the barcode CSV folders and the output
# sub-folders.
name_target_class = "Broad"  # "Broad" | "Simplified" | "Detailed"
# Rebinds the KFold `kf` from the set-up cell: the CV splitter used by the base
# learners and the stacker now comes from MLTraining.CV.
kf          = MLTraining.CV
# Forwarded as `num_cores` to the MLTraining.train_* helpers. NOTE(review):
# presumably the joblib/sklearn "use all cores" convention — confirm in MLTraining.
num_cores   = -1
# Per-head TEST metric dicts, appended in step 4.9 of the training loop.
metrics_log = []

# -----------------------------------------------------------------------------
# EMBEDDING CONFIG (for Class_Train_data.png)
# -----------------------------------------------------------------------------
# Choose where to read the 2D embedding from.
# Supported:
#   - "adata_obsm": read from adata_train.obsm[obsm_key]
#   - "adata_obs":  read from adata_train.obs[[obs_x, obs_y]]
#   - "train_df":   read from train_df[[df_x, df_y]] (e.g., Triana_data_Train has UMAP columns)
# These values are forwarded to MLTraining.save_class_train_umap_pngs (step 4.1b).
EMBEDDING_SOURCE = "adata_obsm"   # "adata_obsm" | "adata_obs" | "train_df"

# If EMBEDDING_SOURCE == "adata_obsm"
EMBEDDING_OBSM_KEY = "X_mofaumap"     # e.g. "X_umap", "X_pca"

# If EMBEDDING_SOURCE == "adata_obs"
EMBEDDING_OBS_X = "UMAP_1"
EMBEDDING_OBS_Y = "UMAP_2"

# If EMBEDDING_SOURCE == "train_df"
EMBEDDING_DF_X = "UMAP_1"
EMBEDDING_DF_Y = "UMAP_2"

# -----------------------------------------------------------------------------
# ROOTS
# -----------------------------------------------------------------------------
# All outputs live under `models_output` (data_path + "/Triana"), split into a
# Dev tree and a Release tree, each partitioned by annotation depth.
Triana_root = Path(models_output)

# NOTE(review): the Dev paths nest name_target_class twice
# (Dev/<depth>/Models/<depth>, .../Reports/<depth>, .../Figures/<depth>), whereas
# Release uses Release/<depth>/Models. This is left unchanged in case other code
# reads these locations — confirm whether the double nesting is intended.
dev_root     = Triana_root / "Dev"
models_root  = dev_root / name_target_class / "Models"  / name_target_class
reports_root = dev_root / name_target_class / "Reports" / name_target_class
fig_root     = dev_root / name_target_class / "Figures" / name_target_class

heads_dir    = models_root / "heads"
metrics_dir  = reports_root / "metrics"
probs_dir    = reports_root / "probabilities"
fig_percls   = fig_root / "per_class"
dev_importances = reports_root / "Importances"

release_root     = Triana_root / "Release"
release_models   = release_root / name_target_class / "Models"
release_reports  = release_root / name_target_class / "Reports"
release_metrics  = release_reports / "Metrics"
release_probs    = release_reports / "Probabilities"
release_imps     = release_reports / "Importances"
release_figs     = release_root / name_target_class / "Figures"
release_single   = release_figs / "Single_classes"

# Create directories conditionally: only the trees whose export flag is on.
if EXPORT_DEV:
    for p in (models_root, heads_dir, reports_root, metrics_dir, probs_dir, fig_root, fig_percls, dev_importances):
        p.mkdir(parents=True, exist_ok=True)
    print(f"[INFO] DEV Models:  {models_root}")
    print(f"[INFO] DEV Reports: {reports_root}")
    print(f"[INFO] DEV Figures: {fig_root}")

if EXPORT_RELEASE:
    for p in (release_models, release_reports, release_metrics, release_probs, release_imps, release_figs, release_single):
        p.mkdir(parents=True, exist_ok=True)
    print(f"[INFO] RELEASE Root:    {release_root}")
    print(f"[INFO] RELEASE Models:  {release_models}")
    print(f"[INFO] RELEASE Reports: {release_reports}")
    print(f"[INFO] RELEASE Figures: {release_figs}")


# =============================================================================
# SECTION 1: ATTACH CELL-TYPE LABELS
# =============================================================================
# Copy the consensus annotation for the configured depth from each split's
# AnnData object onto its feature table. NOTE(review): presumably stored as the
# "Celltype" column that the sections below read — confirm in
# MLTraining.attach_celltype.
print("\n[STEP 1] Attaching cell-type labels from AnnData.obs...")

consensus_field = f"Consensus_annotation_{name_target_class.lower()}_final"

Triana_data_Train = MLTraining.attach_celltype(Triana_data_Train, Triana_dataset_Train, consensus_field)
Triana_data_Test  = MLTraining.attach_celltype(Triana_data_Test,  Triana_dataset_Test,  consensus_field)
Triana_data_Cal   = MLTraining.attach_celltype(Triana_data_Cal,   Triana_dataset_Cal,   consensus_field)

print(f"  ✓ Attached '{consensus_field}' to Train/Test/Cal splits")


# =============================================================================
# SECTION 2: ALIGN DATA COLUMNS TO REFERENCE PANEL
# =============================================================================
print("\n[STEP 2] Aligning data columns to reference panel (exact names preserved)...")

# Reference panel feature names as strings, in panel order.
panel = pd.Index(map(str, TotalSeqD_Heme_Oncology_CAT399906))
# Build a normalised-key -> exact-panel-name lookup so that data columns can be
# matched to panel names despite formatting differences.
panel_keys = MLTraining.norm_feats(panel)
norm_to_panel = dict(zip(panel_keys, panel))
# Fewer dict entries than panel names means two names share a normalised key,
# which would make the mapping ambiguous.
if len(norm_to_panel) != len(panel):
    raise ValueError("Panel contains names that collide after normalization. Adjust MLTraining.norm_feats rules.")

def rename_data_to_panel(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` whose feature columns carry their exact panel names.

    Feature columns are every column except ``cell_barcode`` / ``Celltype``.
    Each one is normalised with ``MLTraining.norm_feats`` and looked up in the
    module-level ``norm_to_panel`` map. Columns without a panel match keep
    their name. When several columns map to the same panel name, the first one
    (in column order) is renamed and the later ones are dropped, with a warning.

    Args:
        df: Feature table for one split; it is not modified.

    Returns:
        The renamed (and possibly column-reduced) copy.
    """
    out = df.copy()
    meta_cols = [col for col in ("cell_barcode", "Celltype") if col in out.columns]
    feature_cols = pd.Index([col for col in out.columns if col not in meta_cols])

    panel_targets = (norm_to_panel.get(key) for key in MLTraining.norm_feats(feature_cols))
    candidate_map = {old: new for old, new in zip(feature_cols, panel_targets) if new is not None}

    # First column to claim a panel name wins; later claimants are dropped.
    claimed_targets = set()
    keep_map = {}
    duplicates = []
    for old, new in candidate_map.items():
        if new not in claimed_targets:
            claimed_targets.add(new)
            keep_map[old] = new
            continue
        duplicates.append(old)

    if duplicates:
        print(f"  [WARN] Dropping {len(duplicates)} duplicated-mapped columns (sample: {duplicates[:5]})")
        out = out.drop(columns=duplicates, errors="ignore")

    out = out.rename(columns=keep_map)
    print(f"  ✓ Matched {len(keep_map)}/{len(feature_cols)} data columns to panel")
    return out

def panel_intersection(df: pd.DataFrame) -> pd.DataFrame:
    """Keep only the features shared with the module-level ``panel``, in panel order.

    The metadata columns ``cell_barcode`` / ``Celltype`` (whichever exist) are
    appended after the features.

    Raises:
        ValueError: if no feature column of ``df`` is a panel name.
    """
    meta_cols = [col for col in ("cell_barcode", "Celltype") if col in df.columns]
    candidate_features = pd.Index([col for col in df.columns if col not in meta_cols])
    shared = panel.intersection(candidate_features, sort=False)
    if shared.empty:
        raise ValueError("Panel/Data intersection is empty after renaming. Check mapping rules.")
    return df.reindex(columns=[*shared, *meta_cols])

# Rename every split's columns to exact panel names, then restrict each split to
# the panel features (panel order) followed by its metadata columns.
Triana_data_Train = panel_intersection(rename_data_to_panel(Triana_data_Train))
Triana_data_Test  = panel_intersection(rename_data_to_panel(Triana_data_Test))
Triana_data_Cal   = panel_intersection(rename_data_to_panel(Triana_data_Cal))

print("  ✓ Data columns now aligned to panel (panel order preserved)")


# =============================================================================
# SECTION 3: PREPARE FEATURES & LABELS
# =============================================================================
print("\n[STEP 3] Extracting features and labels...")

# CAL labels, kept separately from the CAL feature matrix.
Triana_data_Cal_lbl = Triana_data_Cal[["Celltype"]].copy()

# Feature matrices = every column except the metadata columns.
drop_cols_train = [c for c in ["cell_barcode", "Celltype"] if c in Triana_data_Train.columns]
drop_cols_test  = [c for c in ["cell_barcode", "Celltype"] if c in Triana_data_Test.columns]
drop_cols_cal   = [c for c in ["cell_barcode", "Celltype"] if c in Triana_data_Cal.columns]

Triana_data_Train_Sub = Triana_data_Train.drop(columns=drop_cols_train, errors="ignore")
Triana_data_Test_Sub  = Triana_data_Test.drop(columns=drop_cols_test,  errors="ignore")
Triana_data_Cal_Sub   = Triana_data_Cal.drop(columns=drop_cols_cal,    errors="ignore")

# All splits must expose identical features in identical order, because one
# scaler per head is fitted on TRAIN and applied to every split.
cols_train = list(Triana_data_Train_Sub.columns)
if list(Triana_data_Test_Sub.columns) != cols_train or list(Triana_data_Cal_Sub.columns) != cols_train:
    raise ValueError("Train/Cal/Test feature columns differ after panel intersection!")

MLTraining.check_finite(Triana_data_Train_Sub, "TRAIN")
MLTraining.check_finite(Triana_data_Test_Sub,  "TEST")
MLTraining.check_finite(Triana_data_Cal_Sub,   "CAL")

print(f"  ✓ Using {len(cols_train)} panel-intersected features (exact panel names)")
print(f"    Sample: {cols_train[:5]}...")

# Class vocabulary = sorted non-null TRAIN labels; index k is column k of every
# probability matrix below.
class_names  = sorted(pd.Series(Triana_data_Train["Celltype"]).dropna().unique())
K            = len(class_names)
class_to_idx = {c: i for i, c in enumerate(class_names)}
print(f"  ✓ Found {K} classes")

# Integer CAL/TEST targets. A label unseen in TRAIN (or a missing label) maps to
# NaN and aborts the run.
s_cal = Triana_data_Cal_lbl["Celltype"].map(class_to_idx)
if s_cal.isna().any():
    missing = Triana_data_Cal_lbl.loc[s_cal.isna(), "Celltype"].unique()
    raise ValueError(f"Unknown labels in CAL split: {missing}")
y_cal_multiclass = s_cal.to_numpy(dtype=np.int64)

s_te = Triana_data_Test["Celltype"].map(class_to_idx)
if s_te.isna().any():
    missing = Triana_data_Test.loc[s_te.isna(), "Celltype"].unique()
    raise ValueError(f"Unknown labels in TEST split: {missing}")
y_test_multiclass = s_te.to_numpy(dtype=np.int64)

X_cal_all_df = Triana_data_Cal_Sub.copy()
X_te_all_df  = Triana_data_Test_Sub.copy()
test_index   = Triana_data_Test_Sub.index

# OvR positive-class probabilities, one column per class, filled in step 4.10.
# A class whose head is skipped in step 4 keeps an all-zero column.
P_cal_raw   = np.zeros((X_cal_all_df.shape[0], K), dtype=float)
P_cal_platt = np.zeros((X_cal_all_df.shape[0], K), dtype=float)

P_te_raw    = np.zeros((X_te_all_df.shape[0],  K), dtype=float)
P_te_platt  = np.zeros((X_te_all_df.shape[0],  K), dtype=float)

# label -> per-class head bundle (scaler + raw/Platt stackers), see step 4.8.
heads_mem = {}

# Importances collectors
xgb_shap_rows = []       # mean_abs + corr (later filtered top10/class)
lr_contrib_rows = []     # LR base learner contributions (from stacker_raw)
platt_metrics_rows = []  # per-class logloss/brier pre vs post platt


# =============================================================================
# SECTION 4: TRAIN OvR BINARY HEADS (+ Platt on CAL)
# =============================================================================
# For every class: fit a one-vs-rest stacked binary head on that class's TRAIN
# barcodes, Platt-calibrate it on CAL, and fill column k of the OvR probability
# matrices consumed by the multiclass temperature scaling in step 5.
print(f"\n[STEP 4] Training {K} binary OvR classifiers...\n")

TOP_N = 10                                 # SHAP beeswarm size; also top-N kept per class in step 5c
base_order = ["NB", "XGB", "KNN", "MLP"]   # base-learner names, in StackingClassifier estimator order

# Palette for the per-class TRAIN UMAPs, chosen to match the configured depth.
# Previously PALETTE_BROAD was passed for every depth. Unknown depths still fall
# back to PALETTE_BROAD. NOTE(review): presumably MLTraining looks colours up by
# class label — confirm in save_class_train_umap_pngs.
umap_palette = {
    "Broad":      PALETTE_BROAD,
    "Simplified": PALETTE_SIMPLIFIED,
    "Detailed":   PALETTE_DETAILED,
}.get(name_target_class, PALETTE_BROAD)

for celltype in class_names:
    k = class_to_idx[celltype]
    cls_safe = MLTraining.safe_name(celltype)
    print(f"▸ Processing {cls_safe} (class {k+1}/{K})")

    # 4.1 Load TRAIN barcodes for this class (columns "Positive" / "Negative").
    train_barcodes_df = pd.read_csv(
        f"{train_barcodes_path}/Triana/Consensus_annotation_{name_target_class.lower()}_final/Barcodes_training_class_{cls_safe}.csv",
        index_col=0
    )
    train_positive_barcodes = train_barcodes_df["Positive"].dropna().values
    train_negative_barcodes = train_barcodes_df["Negative"].dropna().values
    all_train_barcodes = np.concatenate([train_positive_barcodes, train_negative_barcodes])

    # Binary target: 1 for positive barcodes, 0 for the negatives that were found.
    train_mask = Triana_data_Train_Sub.index.isin(all_train_barcodes)
    X_tr_df = Triana_data_Train_Sub.loc[train_mask]
    found_train_barcodes = X_tr_df.index.values
    y_tr = np.isin(found_train_barcodes, train_positive_barcodes).astype(int)

    if X_tr_df.empty or np.unique(y_tr).size < 2:
        print(f"  [SKIP] Empty or single-class train (pos={y_tr.sum()}, neg={len(y_tr)-y_tr.sum()})\n")
        continue

    # 4.1b TRAIN UMAP (pos vs rest) + legend; plotting failures never stop training.
    try:
        MLTraining.save_class_train_umap_pngs(
            celltype=str(celltype),
            cls_safe=cls_safe,
            barcodes=found_train_barcodes,
            y_bin=y_tr,
            custom_palette=umap_palette,
            out_dir_dev=fig_percls if EXPORT_DEV else None,
            out_dir_rel=release_single if EXPORT_RELEASE else None,
            adata_train=Triana_dataset_Train,
            train_df=Triana_data_Train,
            embedding_source=EMBEDDING_SOURCE,
            obsm_key=EMBEDDING_OBSM_KEY,
            obs_x=EMBEDDING_OBS_X,
            obs_y=EMBEDDING_OBS_Y,
            df_x=EMBEDDING_DF_X,
            df_y=EMBEDDING_DF_Y,
            neg_color="#A3A3A3",
            outline=(5, 0.05),
            # NOTE(review): debug output only for the "Mature" class — looks
            # like a leftover debugging switch.
            debug=(str(celltype) == "Mature"),
        )

    except Exception as e:
        warnings.warn(f"UMAP train plot failed for '{celltype}': {e}")

    # 4.2 Load TEST barcodes for class-specific metrics (optional)
    test_barcodes_df = pd.read_csv(
        f"{test_barcodes_path}/Triana/Consensus_annotation_{name_target_class.lower()}_final/Barcodes_testing_class_{cls_safe}.csv",
        index_col=0
    )
    test_positive_barcodes = test_barcodes_df["Positive"].dropna().values
    test_negative_barcodes = test_barcodes_df["Negative"].dropna().values
    all_test_barcodes = np.concatenate([test_positive_barcodes, test_negative_barcodes])

    test_mask = Triana_data_Test_Sub.index.isin(all_test_barcodes)
    X_te_df = Triana_data_Test_Sub.loc[test_mask]
    found_test_barcodes = X_te_df.index.values
    y_te = np.isin(found_test_barcodes, test_positive_barcodes).astype(int)

    # Full TEST for head probabilities / calibration plot eval
    X_te_all_local = X_te_all_df
    y_te_all = (Triana_data_Test["Celltype"].values == celltype).astype(int)

    # CAL split for Platt fitting (this class vs everything else)
    X_cal_df  = X_cal_all_df
    y_cal_bin = (Triana_data_Cal_lbl["Celltype"].values == celltype).astype(int)

    # 4.3 Fit scaler on TRAIN only; transform all splits with it.
    scaler = StandardScaler(with_mean=True, with_std=True).fit(X_tr_df.values)

    def _sc(df: pd.DataFrame) -> pd.DataFrame:
        """Apply this class's TRAIN scaler, keeping the index and panel column names."""
        return pd.DataFrame(
            scaler.transform(df.values),
            index=df.index,
            columns=cols_train
        )

    X_tr_sc_df      = _sc(X_tr_df)
    # StandardScaler.transform rejects 0-row input, so a class with no barcodes
    # found in TEST used to crash the whole loop here. The class-specific subset
    # only feeds the optional metrics in 4.9, so it is skipped when empty.
    X_te_sc_df      = _sc(X_te_df) if not X_te_df.empty else None
    X_te_all_sc_df  = _sc(X_te_all_local)
    X_cal_sc_df     = _sc(X_cal_df)

    # 4.4 Train base learners
    NB_model  = MLTraining.train_NB (X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)
    XGB_model = MLTraining.train_XGB(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)
    KNN_model = MLTraining.train_KNN(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)
    MLP_model = MLTraining.train_MLP(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)

    # 4.5 Stacking RAW head (balanced LR meta-learner on base-learner probabilities)
    stacker_raw = StackingClassifier(
        estimators=[("NB", NB_model), ("XGB", XGB_model), ("KNN", KNN_model), ("MLP", MLP_model)],
        final_estimator=LogisticRegression(max_iter=2000, class_weight="balanced", random_state=42),
        stack_method="predict_proba",
        cv=kf,
        n_jobs=-1,
    ).fit(X_tr_sc_df, y_tr)

    # 4.6 Platt calibration (fit on CAL only); needs both classes present in CAL.
    pos_cal   = int(y_cal_bin.sum())
    n_cal_bin = int(len(y_cal_bin))
    has_both  = (0 < pos_cal < n_cal_bin)

    stacker_platt = None
    if has_both:
        stacker_platt = MLTraining.calibrate_prefit(stacker_raw, X_cal_sc_df, y_cal_bin, method="sigmoid")
    else:
        print("    [WARN] Skipped Platt calibration (single-class CAL)")

    # Full-TEST positive-class probabilities, computed once: they feed both the
    # Platt evaluation plot (4.7) and the OvR probability matrices (4.10).
    p_test_raw   = stacker_raw.predict_proba(X_te_all_sc_df)[:, 1]
    p_test_platt = stacker_platt.predict_proba(X_te_all_sc_df)[:, 1] if stacker_platt is not None else None

    # 4.7 Platt evaluation curve on TEST (Ideal -> RAW -> Platt) + metrics row
    try:
        dev_platt = (fig_percls / f"{cls_safe}_Platt_calibration_evaluation_on_test.png") if EXPORT_DEV else None
        rel_platt = (release_single / f"{cls_safe}_Platt_calibration_evaluation_on_test.png") if EXPORT_RELEASE else None

        ll_raw, br_raw, ll_pl, br_pl, pl_avail = MLTraining.plot_platt_calibration_on_test(
            y_true_bin=y_te_all.astype(int),
            p_raw=p_test_raw,
            p_platt=p_test_platt,
            title=f"{name_target_class} – {celltype}: Platt calibration evaluation on TEST",
            out_png_dev=dev_platt,
            out_png_rel=rel_platt,
            n_bins=15,
        )

        platt_metrics_rows.append({
            "depth": name_target_class,
            "class_name": str(celltype),
            "n_test_samples": int(len(y_te_all)),
            "n_test_positive": int(y_te_all.sum()),
            "logloss_raw": ll_raw,
            "brier_raw": br_raw,
            "logloss_platt": ll_pl,
            "brier_platt": br_pl,
            "platt_available": bool(pl_avail),
        })

    except Exception as e:
        warnings.warn(f"Platt calibration plot failed for class '{celltype}': {e}")

    # 4.8 Save per-class head bundle + keep in-memory for package
    head_bundle = {
        "atlas": "Triana",
        "depth": name_target_class,
        "label": str(celltype),
        "panel_name": "TotalSeqD_Heme_Oncology_CAT399906",
        "columns": cols_train,
        "scaler": scaler,
        "model_raw": stacker_raw,
        "model_platt": stacker_platt,
    }
    heads_mem[str(celltype)] = head_bundle

    if EXPORT_DEV:
        joblib.dump(head_bundle, heads_dir / f"{cls_safe}.joblib")

    # 4.9 Optional per-head metrics on the class-specific TEST subset. Failures
    # are reported (previously swallowed silently) but never stop training.
    if X_te_sc_df is None:
        print("    [WARN] No class-specific TEST barcodes found; skipped per-head metrics")
    else:
        try:
            model_for_eval = stacker_platt if stacker_platt is not None else stacker_raw
            m = MLTraining.evaluate_classifier(model_for_eval, X_te_sc_df, y_te, plot_cm=False)
            m.update(celltype=str(celltype), used_platt=bool(stacker_platt is not None))
            metrics_log.append(m)
        except Exception as e:
            print(f"    [WARN] Per-head metrics failed: {e}")

    # 4.10 OvR probability matrices (RAW + PLATT) for multiclass downstream.
    # Without a Platt head, the "platt" column falls back to the raw probabilities.
    P_cal_raw[:, k] = stacker_raw.predict_proba(X_cal_sc_df)[:, 1]
    P_te_raw[:,  k] = p_test_raw

    if stacker_platt is not None:
        P_cal_platt[:, k] = stacker_platt.predict_proba(X_cal_sc_df)[:, 1]
        P_te_platt[:,  k] = p_test_platt
    else:
        P_cal_platt[:, k] = P_cal_raw[:, k]
        P_te_platt[:,  k] = P_te_raw[:,  k]

    # 4.11 SHAP: mean_abs + corr on TEST; beeswarm TRAIN only
    if HAS_SHAP:
        try:
            shap_sum_test = MLTraining.xgb_shap_mean_abs_and_corr(XGB_model, X_te_all_sc_df, class_index=1)
            shap_sum_test["depth"] = name_target_class
            shap_sum_test["class_name"] = str(celltype)
            shap_sum_test["dataset"] = "TEST"
            xgb_shap_rows.extend(shap_sum_test.to_dict(orient="records"))

            # Beeswarm on TRAIN only. The TRAIN SHAP values are computed once and
            # reused for the DEV and RELEASE figures (previously computed twice).
            beeswarm_targets = []
            if EXPORT_DEV:
                beeswarm_targets.append(fig_percls / f"{cls_safe}_SHAP_beeswarm_TRAIN.png")
            if EXPORT_RELEASE:
                beeswarm_targets.append(release_single / f"{cls_safe}_SHAP_beeswarm_TRAIN.png")
            if beeswarm_targets:
                sv, _ = MLTraining.xgb_shap_values(XGB_model, X_tr_sc_df, class_index=1)
                for outp in beeswarm_targets:
                    MLTraining.plot_xgb_shap_beeswarm(
                        sv, X_tr_sc_df,
                        title=f"{name_target_class} – {celltype}: SHAP importance",
                        max_display=TOP_N,
                        out_path=outp,
                        figsize=(6, 7),
                    )

        except Exception as e:
            warnings.warn(f"SHAP failed for class '{celltype}': {e}")

    # 4.12 LR meta-learner contributions per base learner (on full TEST).
    # `_lr_baselearner_contributions` is defined elsewhere in the notebook; if it
    # moves into MLTraining.py, update this call accordingly.
    try:
        contrib = _lr_baselearner_contributions(stacker_raw, X_te_all_sc_df, base_order=base_order)  # existing in notebook
        row = {
            "depth": name_target_class,
            "class_name": str(celltype),
            "dataset": "TEST",
            "n_meta_features": contrib["n_meta_features"],
            "per_estimator_meta_cols": contrib["per_estimator_meta_cols"],
        }
        for b in base_order:
            row[f"{b}_mean_abs_contribution"] = contrib["per_base"].get(b, {}).get("mean_abs_contribution", 0.0)
            row[f"{b}_coef_l1"]               = contrib["per_base"].get(b, {}).get("coef_l1", 0.0)
            row[f"{b}_n_meta_cols"]           = contrib["per_base"].get(b, {}).get("n_cols", 0)
        lr_contrib_rows.append(row)
    except Exception as e:
        warnings.warn(f"LR contribution extraction failed for class '{celltype}': {e}")

    print("")


# =============================================================================
# EXPORT: Per-class LogLoss & Brier (pre vs post Platt) on TEST
# =============================================================================
# Writes the rows collected in step 4.7 to the Dev and/or Release metrics folder
# (a None destination is passed for a disabled target).
print("\n[EXPORT] Per-class calibration metrics (RAW vs Platt on TEST)...")

_ = MLTraining.export_platt_metrics_csv(
    platt_metrics_rows,
    out_dev=metrics_dir if EXPORT_DEV else None,
    out_rel=release_metrics if EXPORT_RELEASE else None,
    filename="Single_classes_metrics_pre_and_post_platt_calibration.csv",
)


# =============================================================================
# SECTION 5: MULTICLASS TEMPERATURE SCALING (fit on CAL using PLATT matrix)
# =============================================================================
print("\n[STEP 5] Multiclass Temperature Scaling on CAL (using Platt OvR probabilities)...")

def _check_probs(P: np.ndarray, name: str):
    if np.isnan(P).any() or np.isinf(P).any():
        raise ValueError(f"{name} contains NaN/Inf")
    if (P < 0).any() or (P > 1).any():
        raise ValueError(f"{name} contains values outside [0,1]")

# Both OvR matrices must hold valid probabilities before temperature scaling.
_check_probs(P_cal_platt, "P_cal_platt")
_check_probs(P_te_platt,  "P_te_platt")

# Fit the multiclass temperature on CAL (Platt OvR probabilities vs integer
# labels), then apply it to TEST.
ts_cal = TemperatureScaling()
ts_cal.fit(P_cal_platt, y_cal_multiclass)
P_te_cal = ts_cal.transform(P_te_platt)

# Bring the output to an (n_test, K) matrix:
#   - 1-D output is treated as a single column;
#   - a single column with K == 2 is taken as the class-1 probability;
#   - any other width mismatch falls back to row-normalised OvR probabilities.
P_te_cal = np.asarray(P_te_cal)
if P_te_cal.ndim == 1:
    P_te_cal = P_te_cal.reshape(-1, 1)
if P_te_cal.shape[1] == 1 and K == 2:
    P_te_cal = np.hstack([1.0 - P_te_cal, P_te_cal])
elif P_te_cal.shape[1] != K:
    # Record the unexpected shape BEFORE P_te_cal is overwritten. The old
    # message printed the fallback's shape, never the returned one.
    ts_returned_shape = P_te_cal.shape
    row_sums = P_te_platt.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0.0] = 1.0  # avoid 0/0 for rows where every head predicts 0
    P_te_cal = P_te_platt / row_sums
    print(f"  [WARN] TemperatureScaling returned shape {ts_returned_shape}; fell back to sum-normalized OvR probs")

if EXPORT_DEV:
    joblib.dump(ts_cal, models_root / "temp_scaler.joblib")
    pd.Series(class_names, name="class_name").to_csv(models_root / "class_names.csv", index=False)


# =============================================================================
# SECTION 5b: SAVE DEPLOYABLE PACKAGE(S)
# =============================================================================
print("\n[STEP 5b] Saving deployable package(s)...")

# One bundle with everything needed for inference at this depth: the class
# vocabulary (index order = probability column order), the per-class heads from
# step 4.8 (columns, scaler, raw and Platt stackers) and the fitted temperature
# scaler. Classes skipped in step 4 have no entry in "heads".
package = {
    "atlas": "Triana",
    "depth": name_target_class,
    "panel_name": "TotalSeqD_Heme_Oncology_CAT399906",
    "class_names": class_names,
    "heads": heads_mem,
    "temp_scaler": ts_cal,
}

if EXPORT_DEV:
    joblib.dump(package, models_root / "package.joblib")

if EXPORT_RELEASE:
    joblib.dump(package, release_models / "Multiclass_models.joblib")


# =============================================================================
# SECTION 5c: EXPORT IMPORTANCES (Top10 per class)
# =============================================================================
print("\n[STEP 5c] Exporting importances (Top 10 per class; SHAP mean_abs + corr + LR)...")

# SHAP export (Top10/class; keep corr_feature_value_vs_shap)
shap_df = None
if len(xgb_shap_rows) > 0:
    shap_df = pd.DataFrame(xgb_shap_rows)

    # Keep the TOP_N features with the largest mean |SHAP| per (depth, class).
    shap_df = (
        shap_df.sort_values(["depth", "class_name", "mean_abs_shap"], ascending=[True, True, False])
               .groupby(["depth", "class_name"], as_index=False)
               .head(TOP_N)
    )

    # 1-based rank within each class (1 = largest mean |SHAP|; ties by order).
    shap_df["rank_within_class"] = (
        shap_df.groupby(["depth", "class_name"])["mean_abs_shap"]
               .rank(ascending=False, method="first")
               .astype(int)
    )

    # NOTE(review): assumes MLTraining.xgb_shap_mean_abs_and_corr returns
    # "feature", "mean_abs_shap" and "corr_feature_value_vs_shap" columns;
    # a missing one raises KeyError here.
    keep_cols = [
        "depth", "class_name", "dataset",
        "feature", "mean_abs_shap", "corr_feature_value_vs_shap",
        "rank_within_class",
    ]
    shap_df = shap_df[keep_cols]

    if EXPORT_DEV:
        shap_df.to_csv(dev_importances / "SHAP_XGB_Feature_importances.csv", index=False)
    if EXPORT_RELEASE:
        shap_df.to_csv(release_imps / "SHAP_XGB_Feature_importances.csv", index=False)
else:
    print("  [INFO] No SHAP rows collected (or SHAP not installed).")

# LR export: one row per class with each base learner's contribution to the
# meta-learner (collected in step 4.12).
lr_df = None
if len(lr_contrib_rows) > 0:
    lr_df = pd.DataFrame(lr_contrib_rows)
    if EXPORT_DEV:
        lr_df.to_csv(dev_importances / "LR_MetaLearner_BaseLearner_contributions.csv", index=False)
    if EXPORT_RELEASE:
        lr_df.to_csv(release_imps / "LR_MetaLearner_BaseLearner_contributions.csv", index=False)
else:
    print("  [INFO] No LR contribution rows collected.")


# =============================================================================
# SECTION 6: SAVE PROBABILITIES
# =============================================================================
print("\n[STEP 6] Saving probability outputs...")

# DEV: raw OvR, Platt OvR and temperature-scaled probabilities side by side, plus
# the true label and argmax predictions (index + name) for raw and calibrated.
if EXPORT_DEV:
    probs_raw_df   = pd.DataFrame(P_te_raw,   index=test_index, columns=[f"raw_{c}"   for c in class_names])
    probs_platt_df = pd.DataFrame(P_te_platt, index=test_index, columns=[f"platt_{c}" for c in class_names])
    probs_cal_df   = pd.DataFrame(P_te_cal,   index=test_index, columns=[f"cal_{c}"   for c in class_names])

    probs_dev = pd.concat([probs_raw_df, probs_platt_df, probs_cal_df], axis=1)
    probs_dev["true_label"] = Triana_data_Test["Celltype"].values
    probs_dev["pred_raw"]   = P_te_raw.argmax(axis=1)
    probs_dev["pred_cal"]   = P_te_cal.argmax(axis=1)
    probs_dev["pred_raw_name"] = [class_names[i] for i in probs_dev["pred_raw"].values]
    probs_dev["pred_cal_name"] = [class_names[i] for i in probs_dev["pred_cal"].values]

    probs_dev_path = probs_dir / "probabilities_before_after_TEST.csv"
    probs_dev.to_csv(probs_dev_path, index=True)

# RELEASE: only the calibrated probabilities, the true label, the calibrated
# prediction and its probability (row max).
if EXPORT_RELEASE:
    probs_cal_df = pd.DataFrame(P_te_cal, index=test_index, columns=[f"cal_{c}" for c in class_names])
    probs_release = probs_cal_df.copy()
    probs_release["true_label"]    = Triana_data_Test["Celltype"].values
    probs_release["pred_cal"]      = P_te_cal.argmax(axis=1)
    probs_release["pred_cal_name"] = [class_names[i] for i in probs_release["pred_cal"].values]
    probs_release["max_cal_prob"]  = probs_cal_df.max(axis=1).values

    release_probs_path = release_probs / "Multiclass_models_probabilities_on_test.csv"
    probs_release.to_csv(release_probs_path, index=True)


# =============================================================================
# SECTION 7: MULTICLASS EVALUATION (TEST) — using CAL probabilities
# =============================================================================
print("\n[STEP 7] Multiclass evaluation (TEST; using CAL probs)...\n")

# Predicted class index = column with the highest temperature-scaled probability.
y_pred_cal = P_te_cal.argmax(axis=1)

# labels=list(range(K)) makes the report always contain one row per entry of
# class_names. Without it, sklearn infers the labels from the classes present in
# y_true/y_pred and raises ValueError (target_names length mismatch) as soon as a
# class is absent from both, e.g. a rare class or a head skipped in step 4.
report_txt = classification_report(
    y_test_multiclass, y_pred_cal, labels=list(range(K)), target_names=class_names, digits=3
)
print("Multiclass Classification Report (TEST):")
print(report_txt)

cm_mc = confusion_matrix(y_test_multiclass, y_pred_cal, labels=range(K))
print("\nConfusion Matrix (rows=true, cols=pred):")
print(cm_mc)

report = classification_report(
    y_test_multiclass, y_pred_cal, labels=list(range(K)), target_names=class_names, output_dict=True
)
report_df = pd.DataFrame(report).T

# K x K confusion matrix labelled by class name (rows = true, cols = predicted).
cm_mc_df = pd.DataFrame(
    cm_mc,
    index=pd.Index(class_names, name="true"),
    columns=pd.Index(class_names, name="pred"),
)

if EXPORT_DEV:
    report_df.to_csv(metrics_dir / "multiclass_classification_report_TEST.csv")
    cm_mc_df.to_csv(metrics_dir / "multiclass_confusion_matrix_TEST.csv")

if EXPORT_RELEASE:
    report_df.to_csv(release_metrics / "Multiclass_models_metrics_on_test.csv")
    cm_mc_df.to_csv(release_metrics / "Multiclass_models_confusion_matrix_on_test.csv")


# =============================================================================
# SECTION 8: FIGURES (MULTICLASS CM + PER-CLASS CONF & ROC)
# =============================================================================
print("\n[STEP 8] Saving plots...")

def _save_multiclass_cm_png(out_path: Path):
    """Save the multiclass TEST confusion matrix (module-level ``cm_mc``) as a PNG.

    Tick labels come from ``class_names`` and the title from
    ``name_target_class``.

    Args:
        out_path: Destination PNG path (written at 300 dpi).
    """
    # ConfusionMatrixDisplay.plot() creates a NEW figure unless `ax` is given.
    # Before this fix, the 7x6 figure stayed blank (figsize ignored) and the
    # figure actually drawn was never closed. Drawing into explicit axes fixes both.
    fig, ax = plt.subplots(figsize=(7, 6))
    try:
        disp = ConfusionMatrixDisplay(confusion_matrix=cm_mc, display_labels=class_names)
        disp.plot(ax=ax, values_format="d", cmap="Blues", colorbar=False)
        ax.set_title(f"{name_target_class} – Multiclass Confusion Matrix (on TEST)")
        fig.tight_layout()
        fig.savefig(out_path, dpi=300)
    finally:
        plt.close(fig)  # release the figure even if drawing/saving fails

if EXPORT_DEV:
    _save_multiclass_cm_png(fig_root / "multiclass_confusion_matrix_TEST.png")

if EXPORT_RELEASE:
    _save_multiclass_cm_png(release_figs / "Multiclass_models_confusion_matrix_on_test.png")

# Collector for per-class metric rows in the loop below.
per_class_rows = []

# Argmax predictions from the raw OvR and the temperature-scaled matrices.
# y_pred_cal is recomputed here with the same expression as in step 7.
y_pred_raw = P_te_raw.argmax(axis=1)
y_pred_cal = P_te_cal.argmax(axis=1)

def _metrics_from_cm(cm2x2):
    tn, fp, fn, tp = cm2x2.ravel()
    support = tp + fn
    prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    rec  = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1   = (2 * prec * rec) / (prec + rec) if (prec + rec) > 0 else 0.0
    return dict(TP=int(tp), FP=int(fp), TN=int(tn), FN=int(fn),
                support=int(support), precision=prec, recall=rec, f1=f1)

def _save_cm_fig(cm2x2, cls_label, title, out_dev: Path | None, out_rel: Path | None):
    """Render a 2x2 one-vs-rest confusion matrix and save it to each given path.

    Args:
        cm2x2: 2x2 matrix with rows/cols ordered [Other, cls_label].
        cls_label: Display label of the positive class.
        title: Figure title.
        out_dev: DEV PNG path, or None to skip.
        out_rel: RELEASE PNG path, or None to skip.
    """
    # ConfusionMatrixDisplay.plot() creates a NEW figure unless `ax` is given.
    # Before this fix, the 5.5x5.0 figure stayed blank (figsize ignored) and
    # every drawn per-class figure leaked. Drawing into explicit axes fixes both.
    fig, ax = plt.subplots(figsize=(5.5, 5.0))
    try:
        ConfusionMatrixDisplay(confusion_matrix=cm2x2, display_labels=["Other", cls_label]).plot(
            ax=ax, values_format="d", cmap="Blues", colorbar=False
        )
        ax.set_title(title)
        fig.tight_layout()
        for target in (out_dev, out_rel):
            if target is not None:
                fig.savefig(target, dpi=300)
    finally:
        plt.close(fig)  # release the figure even if drawing/saving fails

def _save_roc(y_true, y_score, title, out_dev: Path | None, out_rel: Path | None):
    """Draw a ROC curve (with chance diagonal), save it, and return its AUC.

    Args:
        y_true: Binary ground truth (1 = positive class).
        y_score: Score/probability for the positive class.
        title: Title prefix; " AUC=<value>" is appended.
        out_dev: DEV output path, or None to skip.
        out_rel: RELEASE output path, or None to skip.

    Returns:
        Area under the ROC curve.
    """
    false_pos_rate, true_pos_rate, _thresholds = roc_curve(y_true, y_score)
    area = auc(false_pos_rate, true_pos_rate)

    fig, ax = plt.subplots(figsize=(6.0, 5.5))
    ax.plot([0, 1], [0, 1], linestyle="--", linewidth=1, color="gray")  # chance line
    ax.plot(false_pos_rate, true_pos_rate)
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title(f"{title} AUC={area:.3f}")
    fig.tight_layout()

    for target in (out_dev, out_rel):
        if target is not None:
            fig.savefig(target, dpi=300)

    plt.close(fig)
    return area

# One-vs-rest view of the multiclass TEST predictions: for every class k, build
# binary confusion matrices (RAW and CAL argmax), CM figures, ROC figures, and
# one metrics row per variant.
for k, cls in enumerate(class_names):
    cls_safe = MLTraining.safe_name(cls)
    # binary truth: 1 where the TEST label index equals k
    y_true_bin = (y_test_multiclass == k).astype(int)

    # per-class scores used for the ROC curves
    score_raw = P_te_raw[:, k]
    score_cal = P_te_cal[:, k]

    # binary hard predictions: 1 where the multiclass argmax picked class k
    y_pred_raw_bin = (y_pred_raw == k).astype(int)
    y_pred_cal_bin = (y_pred_cal == k).astype(int)

    # labels=[0, 1] forces a 2x2 matrix even if a class is absent in TEST
    cm_raw = confusion_matrix(y_true_bin, y_pred_raw_bin, labels=[0, 1])
    cm_cal = confusion_matrix(y_true_bin, y_pred_cal_bin, labels=[0, 1])

    if EXPORT_DEV:
        idx = pd.Index(["True=Other", f"True={cls}"], name="true")
        cols = pd.Index(["Pred=Other", f"Pred={cls}"], name="pred")
        pd.DataFrame(cm_raw, index=idx, columns=cols).to_csv(metrics_dir / f"{cls_safe}_binary_confmat_TEST_ARGMAX_RAW.csv")
        pd.DataFrame(cm_cal, index=idx, columns=cols).to_csv(metrics_dir / f"{cls_safe}_binary_confmat_TEST_ARGMAX_CAL.csv")

    # Output paths are None when the corresponding export switch is off.
    dev_out = (fig_percls / f"{cls_safe}_RAW_confusion_matrix_on_test.png") if EXPORT_DEV else None
    rel_out = (release_single / f"{cls_safe}_RAW_confusion_matrix_on_test.png") if EXPORT_RELEASE else None
    _save_cm_fig(cm_raw, cls, f"{name_target_class} – {cls}: Confusion Matrix (RAW; pre-Platt & Temp)", dev_out, rel_out)

    dev_out = (fig_percls / f"{cls_safe}_CAL_confusion_matrix_on_test.png") if EXPORT_DEV else None
    rel_out = (release_single / f"{cls_safe}_CAL_confusion_matrix_on_test.png") if EXPORT_RELEASE else None
    _save_cm_fig(cm_cal, cls, f"{name_target_class} – {cls}: Confusion Matrix (CAL; Platt & Temp)", dev_out, rel_out)

    dev_out = (fig_percls / f"{cls_safe}_RAW_ROC_on_test.png") if EXPORT_DEV else None
    rel_out = (release_single / f"{cls_safe}_RAW_ROC_on_test.png") if EXPORT_RELEASE else None
    auc_raw = _save_roc(
        y_true_bin,
        score_raw,
        f"{name_target_class} – {cls}: ROC (RAW; pre-Platt & Temp)",
        dev_out,
        rel_out,
    )

    dev_out = (fig_percls / f"{cls_safe}_CAL_ROC_on_test.png") if EXPORT_DEV else None
    rel_out = (release_single / f"{cls_safe}_CAL_ROC_on_test.png") if EXPORT_RELEASE else None
    auc_cal = _save_roc(
        y_true_bin,
        score_cal,
        f"{name_target_class} – {cls}: ROC (CAL; Platt & Temp)",
        dev_out,
        rel_out,
    )

    # One row per (class, variant); consumed by Section 9.
    m_raw = _metrics_from_cm(cm_raw)
    m_raw.update(model="RAW", class_name=cls, auc=auc_raw)
    per_class_rows.append(m_raw)

    m_cal = _metrics_from_cm(cm_cal)
    m_cal.update(model="CAL", class_name=cls, auc=auc_cal)
    per_class_rows.append(m_cal)

if EXPORT_DEV:
    print(f"  ✓ Saved per-class plots (DEV) → {fig_percls}")
if EXPORT_RELEASE:
    print(f"  ✓ Saved per-class plots (RELEASE) → {release_single}")


# =============================================================================
# SECTION 9: SAVE METRICS TABLES
# =============================================================================
print("\n[STEP 9] Saving metrics tables...")

# One row per (class, RAW/CAL) with a fixed column order, sorted for stable output.
per_class_df = pd.DataFrame(per_class_rows)[
    ["class_name", "model", "TP", "FP", "TN", "FN", "support", "precision", "recall", "f1", "auc"]
].sort_values(["class_name", "model"])

if EXPORT_DEV:
    dev_metrics_path = metrics_dir / "per_class_argmax_metrics_TEST_included.csv"
    per_class_df.to_csv(dev_metrics_path, index=False)
    print(f"  ✓ Saved DEV per-class metrics → {dev_metrics_path}")

if EXPORT_RELEASE:
    out_single = release_metrics / "Single_classes_metrics_and_confusion_matrix_on_test.csv"
    per_class_df.to_csv(out_single, index=False)
    print(f"  ✓ Saved RELEASE per-class metrics → {out_single}")

# metrics_log is filled earlier in this cell (outside this view); presumably one
# record per binary head — confirm.
if EXPORT_DEV:
    metrics_df = pd.DataFrame.from_records(metrics_log)
    MLTraining.append_metrics_csv(metrics_df, csv_path=dev_root / "stacker_metrics.csv")
    print(f"  ✓ Appended DEV binary-head metrics → {dev_root / 'stacker_metrics.csv'}")

print("\n✅ BROAD PIPELINE COMPLETE. Exports saved according to EXPORT_DEV / EXPORT_RELEASE.\n")


#### Simplified annotation

In [None]:
# -*- coding: utf-8 -*-
# =============================================================================
# MODEL TRAINING PIPELINE (LEAN MAIN SCRIPT)
#   - RAW vs PLATT vs TEMP-SCALED
#   - DEV/RELEASE exports
#   - Importances: XGB SHAP mean_abs + corr (Top10) + LR meta-learner contributions
#   - Platt calibration plots (Ideal -> RAW -> Platt on top) with TEST LogLoss/Brier in legend
#   - Per-class pre/post Platt metrics exported to CSV
#   - Per-class TRAIN UMAP (pos vs rest) + legend PNG
#
# PATCHES ADDED (to address “plots missing / skipped” symptoms):
#   (A) Optional DEBUG_DIAGNOSTICS: prints output paths + CAL class balance + confirms file writes.
#   (B) Hard traceback on failures (instead of silent warnings) to surface root cause.
#   (C) SHAP beeswarm robustification: optional subsample of TRAIN to avoid memory/time failures.
#   (D) Optional SAFE_SINGLE_THREAD: mitigates fork/thread/numba/TBB instability during SHAP/plotting.
#   (E) Explicit existence checks after savefig (so “saved but not where expected” is obvious).
# =============================================================================

# =============================================================================
# SECTION 0: IMPORTS + CONFIG
# =============================================================================

from pathlib import Path
import joblib
import warnings
import traceback
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import StackingClassifier
from sklearn.metrics import (
    classification_report, confusion_matrix, ConfusionMatrixDisplay,
    roc_curve, auc,
)

import MLTraining  # uses MLTraining.py helpers

# -----------------------------------------------------------------------------
# Palettes
# -----------------------------------------------------------------------------
# Label -> hex colour per annotation depth; the active one is selected via
# name_target_class in CONFIG and passed to the TRAIN UMAP plots.

PALETTE_BROAD = {"Immature": "#0079ea", "Mature": "#AF3434"}

PALETTE_SIMPLIFIED = {
    "HSPC":      "#0079ea",
    "Erythroid": "#c11212",
    "pDC":       "#62E6B8",
    "Monocyte":  "#D27CE3",
    "Myeloid":   "#8D43CD",
    "CD4_T":     "#C99546",
    "CD8_T":     "#6B3317",
    "B":         "#68D827",
    "cDC":       "#16D2E3",
    "Other_T":   "#EDB416",
    "NK":        "#FBEF0D",
}

# NOTE: "Monocyte" and "CD14 Mono" intentionally(?) share one colour — confirm.
PALETTE_DETAILED = {
    "HSC_MPP":            "#0079ea",
    "LMPP":               "#17BECF",
    "GMP":                "#C5E4FF",
    "Myeloid progenitor": "#AEC7E8",
    "Monocyte":           "#D27CE3",
    "CD14 Mono":          "#D27CE3",
    "CD16 Mono":          "#8D43CD",
    "Erythroblast":       "#F30A1A",
    "ErP":                "#D1235A",
    "MEP":                "#E364B0",
    "CD4 T Naive":        "#C99546",
    "CD4 T Memory":       "#C1AF93",
    "CD8 T Naive":        "#4D382E",
    "CD8 T Memory":       "#6B3317",
    "Other_T":            "#EDB416",
    "Treg":               "#6E6C37",
    "B Naive":            "#1C511D",
    "B Memory":           "#68D827",
    "Pro-B":              "#66BB6A",
    "Pre-B":              "#2DBD67",
    "Immature B":         "#91FF7B",
    "Plasma":             "#9DC012",
    "cDC1":               "#76A7CB",
    "cDC2":               "#16D2E3",
    "pDC":                "#69FFCB",
    "NK CD56 bright":     "#F3AC1F",
    "NK CD56 dim":        "#FBEF0D",
}

# Lookup keyed by the allowed values of name_target_class.
PALETTE_BY_DEPTH = {
    "Broad": PALETTE_BROAD,
    "Simplified": PALETTE_SIMPLIFIED,
    "Detailed": PALETTE_DETAILED,
}

# -----------------------------------------------------------------------------
# OPTIONAL: SHAP dependency
# -----------------------------------------------------------------------------
# Broad except on purpose: any failure while importing shap (not only
# ImportError) just disables the SHAP steps.
try:
    import shap  # noqa: F401
    HAS_SHAP = True
except Exception:
    HAS_SHAP = False

# -----------------------------------------------------------------------------
# EXPORT SWITCHES
# -----------------------------------------------------------------------------
EXPORT_RELEASE = True   # artifacts under <models_output>/Release
EXPORT_DEV     = False  # artifacts under <models_output>/Dev

# -----------------------------------------------------------------------------
# CONFIG
# -----------------------------------------------------------------------------
name_target_class = "Simplified"  # "Broad" | "Simplified" | "Detailed"
# Class labels (compared as strings) to leave out of training and CAL/TEST evaluation.
# Use set(), not {}: the literal `{}` is an empty *dict*.
EXCLUDE_CLASSES = set()

# A typo here would otherwise silently give an empty palette (grey/unstyled plots).
if name_target_class not in PALETTE_BY_DEPTH:
    print(f"[WARN] No palette for name_target_class={name_target_class!r}; "
          f"expected one of {list(PALETTE_BY_DEPTH)}. Using an empty palette.")
custom_palette = PALETTE_BY_DEPTH.get(name_target_class, {})
kf          = MLTraining.CV  # CV splitter shared by base learners and the stacker
num_cores   = -1             # forwarded to the MLTraining.train_* helpers
metrics_log = []             # per-head TEST metrics, appended in Section 4

# -----------------------------------------------------------------------------
# DIAGNOSTICS / ROBUSTIFICATION SWITCHES (PATCH)
# -----------------------------------------------------------------------------
DEBUG_DIAGNOSTICS = True
HARD_TRACEBACKS   = True   # if True: prints stack traces when plot/SHAP fails
SHAP_TRAIN_SUBSAMPLE_MAX_N = 5000  # set None to disable subsampling
SAFE_SINGLE_THREAD = False  # set True if you see Numba/TBB fork/thread warnings

# -----------------------------------------------------------------------------
# EMBEDDING CONFIG (for Class_Train_data.png)
# -----------------------------------------------------------------------------
EMBEDDING_SOURCE = "adata_obsm"   # "adata_obsm" | "adata_obs" | "train_df"
EMBEDDING_OBSM_KEY = "X_mofaumap"
EMBEDDING_OBS_X = "UMAP_1"
EMBEDDING_OBS_Y = "UMAP_2"
EMBEDDING_DF_X = "UMAP_1"
EMBEDDING_DF_Y = "UMAP_2"

# -----------------------------------------------------------------------------
# ROOTS
# -----------------------------------------------------------------------------
Triana_root = Path(models_output)

# DEV tree: <root>/Dev/<depth>/{Models,Reports,Figures}/<depth>
dev_root     = Triana_root / "Dev"
models_root  = dev_root / name_target_class / "Models"  / name_target_class
reports_root = dev_root / name_target_class / "Reports" / name_target_class
fig_root     = dev_root / name_target_class / "Figures" / name_target_class

heads_dir       = models_root / "heads"
metrics_dir     = reports_root / "metrics"
probs_dir       = reports_root / "probabilities"
fig_percls      = fig_root / "per_class"
dev_importances = reports_root / "Importances"

# RELEASE tree: <root>/Release/<depth>/{Models,Reports,Figures}
release_root    = Triana_root / "Release"
release_models  = release_root / name_target_class / "Models"
release_reports = release_root / name_target_class / "Reports"
release_metrics = release_reports / "Metrics"
release_probs   = release_reports / "Probabilities"
release_imps    = release_reports / "Importances"
release_figs    = release_root / name_target_class / "Figures"
release_single  = release_figs / "Single_classes"

# Directories are only created for the enabled export trees.
if EXPORT_DEV:
    for p in (models_root, heads_dir, reports_root, metrics_dir, probs_dir, fig_root, fig_percls, dev_importances):
        p.mkdir(parents=True, exist_ok=True)

if EXPORT_RELEASE:
    for p in (release_models, release_reports, release_metrics, release_probs, release_imps, release_figs, release_single):
        p.mkdir(parents=True, exist_ok=True)
    print(f"[INFO] RELEASE Root:    {release_root}")
    print(f"[INFO] RELEASE Models:  {release_models}")
    print(f"[INFO] RELEASE Reports: {release_reports}")
    print(f"[INFO] RELEASE Figures: {release_figs}")

if DEBUG_DIAGNOSTICS:
    print(f"[DEBUG] HAS_SHAP={HAS_SHAP} EXPORT_RELEASE={EXPORT_RELEASE} EXPORT_DEV={EXPORT_DEV}")
    print(f"[DEBUG] release_single={release_single}")
    print(f"[DEBUG] release_imps={release_imps}")
    print(f"[DEBUG] SAFE_SINGLE_THREAD={SAFE_SINGLE_THREAD} SHAP_SUBSAMPLE_MAX_N={SHAP_TRAIN_SUBSAMPLE_MAX_N}")

# =============================================================================
# SECTION 1: ATTACH CELL-TYPE LABELS
# =============================================================================
print("\n[STEP 1] Attaching cell-type labels from AnnData.obs...")

# obs column with the consensus label for the chosen depth,
# e.g. "Consensus_annotation_simplified_final".
consensus_field = f"Consensus_annotation_{name_target_class.lower()}_final"
# Presumably adds a "Celltype" column (read by all later steps) — confirm in MLTraining.
Triana_data_Train = MLTraining.attach_celltype(Triana_data_Train, Triana_dataset_Train, consensus_field)
Triana_data_Test  = MLTraining.attach_celltype(Triana_data_Test,  Triana_dataset_Test,  consensus_field)
Triana_data_Cal   = MLTraining.attach_celltype(Triana_data_Cal,   Triana_dataset_Cal,   consensus_field)

print(f"  ✓ Attached '{consensus_field}' to Train/Test/Cal splits")

# =============================================================================
# SECTION 2: ALIGN DATA COLUMNS TO REFERENCE PANEL
# =============================================================================
print("\n[STEP 2] Aligning data columns to reference panel (exact names preserved)...")

# Reference panel as strings, plus a reverse lookup: normalized key -> exact panel name.
panel = pd.Index(map(str, TotalSeqD_Heme_Oncology_CAT399906))
panel_keys = MLTraining.norm_feats(panel)
norm_to_panel = dict(zip(panel_keys, panel))
# Two panel names collapsing onto one key would make the reverse lookup ambiguous.
if len(norm_to_panel) != len(panel):
    raise ValueError("Panel contains names that collide after normalization. Adjust MLTraining.norm_feats rules.")

def rename_data_to_panel(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` whose feature columns carry exact panel names.

    Feature columns (everything except ``cell_barcode`` / ``Celltype``) are
    normalized with ``MLTraining.norm_feats`` and looked up in the module-level
    ``norm_to_panel``. Columns without a match keep their name. When several
    columns resolve to the same panel name, the first one (in column order) is
    renamed and the others are dropped with a warning.
    """
    renamed = df.copy()
    meta_cols = [col for col in ("cell_barcode", "Celltype") if col in renamed.columns]
    feature_cols = pd.Index([col for col in renamed.columns if col not in meta_cols])

    # source column -> panel name, only for columns whose key is known to the panel
    candidates = {}
    for src, key in zip(feature_cols, MLTraining.norm_feats(feature_cols)):
        target = norm_to_panel.get(key)
        if target is not None:
            candidates[src] = target

    claimed = set()
    keep_map = {}
    dupes = []
    for src, target in candidates.items():
        if target in claimed:
            dupes.append(src)
            continue
        claimed.add(target)
        keep_map[src] = target

    if dupes:
        print(f"  [WARN] Dropping {len(dupes)} duplicated-mapped columns (sample: {dupes[:5]})")
        renamed = renamed.drop(columns=dupes, errors="ignore")

    renamed = renamed.rename(columns=keep_map)
    print(f"  ✓ Matched {len(keep_map)}/{len(feature_cols)} data columns to panel")
    return renamed

def panel_intersection(df: pd.DataFrame) -> pd.DataFrame:
    """Keep only the feature columns present in the module-level ``panel``, in panel order.

    Metadata columns (``cell_barcode`` / ``Celltype``) that exist are appended
    after the features.

    Raises:
        ValueError: if no feature column of ``df`` is in the panel.
    """
    meta_cols = [col for col in ("cell_barcode", "Celltype") if col in df.columns]
    feature_cols = pd.Index([col for col in df.columns if col not in meta_cols])
    shared = panel.intersection(feature_cols, sort=False)
    if len(shared) == 0:
        raise ValueError("Panel/Data intersection is empty after renaming. Check mapping rules.")
    return df.reindex(columns=[*shared, *meta_cols])

# Rename to exact panel names, then keep panel features (panel order) + metadata columns.
Triana_data_Train = panel_intersection(rename_data_to_panel(Triana_data_Train))
Triana_data_Test  = panel_intersection(rename_data_to_panel(Triana_data_Test))
Triana_data_Cal   = panel_intersection(rename_data_to_panel(Triana_data_Cal))
print("  ✓ Data columns now aligned to panel (panel order preserved)")

# =============================================================================
# SECTION 3: PREPARE FEATURES & LABELS (WITH CAL/TEST ROW FILTERING)
# =============================================================================
print("\n[STEP 3] Extracting features and labels...")

Triana_data_Cal_lbl = Triana_data_Cal[["Celltype"]].copy()

# Feature matrices = everything except the metadata columns.
drop_cols_train = [c for c in ["cell_barcode", "Celltype"] if c in Triana_data_Train.columns]
drop_cols_test  = [c for c in ["cell_barcode", "Celltype"] if c in Triana_data_Test.columns]
drop_cols_cal   = [c for c in ["cell_barcode", "Celltype"] if c in Triana_data_Cal.columns]

Triana_data_Train_Sub = Triana_data_Train.drop(columns=drop_cols_train, errors="ignore")
Triana_data_Test_Sub  = Triana_data_Test.drop(columns=drop_cols_test,  errors="ignore")
Triana_data_Cal_Sub   = Triana_data_Cal.drop(columns=drop_cols_cal,    errors="ignore")

cols_train = list(Triana_data_Train_Sub.columns)
if list(Triana_data_Test_Sub.columns) != cols_train or list(Triana_data_Cal_Sub.columns) != cols_train:
    raise ValueError("Train/Cal/Test feature columns differ after panel intersection!")

MLTraining.check_finite(Triana_data_Train_Sub, "TRAIN")
MLTraining.check_finite(Triana_data_Test_Sub,  "TEST")
MLTraining.check_finite(Triana_data_Cal_Sub,   "CAL")

# CAL/TEST features and labels are re-aligned below (and in Section 4) by index
# label via .loc[...index]. Duplicated barcodes would silently multiply rows and
# desynchronise X from y, so fail early with a clear message.
for _split_name, _split_df in (("CAL", Triana_data_Cal_Sub), ("TEST", Triana_data_Test_Sub)):
    if not _split_df.index.is_unique:
        _dups = _split_df.index[_split_df.index.duplicated()].unique()[:5].tolist()
        raise ValueError(
            f"{_split_name} index has duplicated labels (sample: {_dups}); "
            "unique cell barcodes are required for feature/label alignment."
        )

print(f"  ✓ Using {len(cols_train)} panel-intersected features (exact panel names)")
print(f"    Sample: {cols_train[:5]}...")

# classes learned from TRAIN, excluding user-specified (labels compared as str)
all_classes = sorted(pd.Series(Triana_data_Train["Celltype"]).dropna().unique())
class_names = [c for c in all_classes if str(c) not in EXCLUDE_CLASSES]

# Report exactly what the filter above removed (same str() comparison).
excluded_present = [c for c in all_classes if str(c) in EXCLUDE_CLASSES]
if excluded_present:
    print(f"  [INFO] Excluding {len(excluded_present)} classes: {excluded_present}")

K            = len(class_names)
class_to_idx = {c: i for i, c in enumerate(class_names)}
print(f"  ✓ Found {K} classes after exclusions")

# ---- critical: filter CAL/TEST rows to those classes ----
keep_set = set(map(str, class_names))

cal_keep_mask  = Triana_data_Cal_lbl["Celltype"].astype(str).isin(keep_set)
test_keep_mask = Triana_data_Test["Celltype"].astype(str).isin(keep_set)

n_cal_drop  = int((~cal_keep_mask).sum())
n_test_drop = int((~test_keep_mask).sum())

if n_cal_drop > 0:
    dropped = sorted(Triana_data_Cal_lbl.loc[~cal_keep_mask, "Celltype"].astype(str).unique().tolist())
    print(f"  [INFO] Dropping {n_cal_drop} CAL rows with excluded/unknown labels: {dropped}")

if n_test_drop > 0:
    dropped = sorted(Triana_data_Test.loc[~test_keep_mask, "Celltype"].astype(str).unique().tolist())
    print(f"  [INFO] Dropping {n_test_drop} TEST rows with excluded/unknown labels: {dropped}")

# filtered label frames
Triana_data_Cal_lbl_f  = Triana_data_Cal_lbl.loc[cal_keep_mask].copy()
Triana_data_Test_lbl_f = Triana_data_Test.loc[test_keep_mask, ["Celltype"]].copy()

# filtered feature frames (must align by index)
X_cal_all_df = Triana_data_Cal_Sub.loc[Triana_data_Cal_lbl_f.index].copy()
X_te_all_df  = Triana_data_Test_Sub.loc[Triana_data_Test_lbl_f.index].copy()
test_index   = X_te_all_df.index

# map filtered labels to integer class indices
s_cal = Triana_data_Cal_lbl_f["Celltype"].map(class_to_idx)
if s_cal.isna().any():
    missing = Triana_data_Cal_lbl_f.loc[s_cal.isna(), "Celltype"].unique()
    raise ValueError(f"Still-unmapped labels in CAL after filtering: {missing}")
y_cal_multiclass = s_cal.to_numpy(dtype=np.int64)

s_te = Triana_data_Test_lbl_f["Celltype"].map(class_to_idx)
if s_te.isna().any():
    missing = Triana_data_Test_lbl_f.loc[s_te.isna(), "Celltype"].unique()
    raise ValueError(f"Still-unmapped labels in TEST after filtering: {missing}")
y_test_multiclass = s_te.to_numpy(dtype=np.int64)

# probability matrices sized to filtered CAL/TEST (one column per class index)
P_cal_raw   = np.zeros((X_cal_all_df.shape[0], K), dtype=float)
P_cal_platt = np.zeros((X_cal_all_df.shape[0], K), dtype=float)

P_te_raw    = np.zeros((X_te_all_df.shape[0],  K), dtype=float)
P_te_platt  = np.zeros((X_te_all_df.shape[0],  K), dtype=float)

heads_mem = {}  # class label -> head bundle, packaged in Section 5b

xgb_shap_rows      = []
lr_contrib_rows    = []
platt_metrics_rows = []

# =============================================================================
# SECTION 4: TRAIN OvR BINARY HEADS (+ Platt on CAL)
# =============================================================================
print(f"\n[STEP 4] Training {K} binary OvR classifiers...\n")

TOP_N = 10                                # max features shown in the SHAP beeswarm / importance tables
base_order = ["NB", "XGB", "KNN", "MLP"]  # base-learner names, same order as the stacker's estimators


def _report_failure(what: str, exc: BaseException) -> None:
    """Report a non-fatal per-class failure without aborting the loop.

    ``warnings.warn`` is kept, but the set-up cell calls
    ``warnings.filterwarnings('ignore')``, which hides every such warning. When
    ``HARD_TRACEBACKS`` is True, the message and the traceback of the exception
    currently being handled are also printed. Call only from inside an ``except``.

    Args:
        what: Description of the failed step (including the class name).
        exc: The caught exception.
    """
    warnings.warn(f"{what}: {exc}")
    if HARD_TRACEBACKS:
        print(f"    [ERROR] {what}: {exc!r}")
        traceback.print_exc()


for celltype in class_names:
    k = class_to_idx[celltype]
    cls_safe = MLTraining.safe_name(celltype)
    print(f"▸ Processing {cls_safe} (class {k+1}/{K})")

    # 4.1 Load TRAIN barcodes for this class ("Positive" / "Negative" columns)
    train_barcodes_df = pd.read_csv(
        f"{train_barcodes_path}/Triana/Consensus_annotation_{name_target_class.lower()}_final/Barcodes_training_class_{cls_safe}.csv",
        index_col=0
    )
    train_positive_barcodes = train_barcodes_df["Positive"].dropna().values
    train_negative_barcodes = train_barcodes_df["Negative"].dropna().values
    all_train_barcodes = np.concatenate([train_positive_barcodes, train_negative_barcodes])

    train_mask = Triana_data_Train_Sub.index.isin(all_train_barcodes)
    X_tr_df = Triana_data_Train_Sub.loc[train_mask]
    found_train_barcodes = X_tr_df.index.values
    y_tr = np.isin(found_train_barcodes, train_positive_barcodes).astype(int)

    if X_tr_df.empty or np.unique(y_tr).size < 2:
        # NOTE: a skipped class gets no head and keeps all-zero columns in P_cal_*/P_te_*.
        print(f"  [SKIP] Empty or single-class train (pos={y_tr.sum()}, neg={len(y_tr)-y_tr.sum()})\n")
        continue

    # 4.1b TRAIN embedding (pos vs rest) + legend
    try:
        MLTraining.save_class_train_umap_pngs(
            celltype=str(celltype),
            cls_safe=cls_safe,
            barcodes=found_train_barcodes,
            y_bin=y_tr,
            custom_palette=custom_palette,
            out_dir_dev=fig_percls if EXPORT_DEV else None,
            out_dir_rel=release_single if EXPORT_RELEASE else None,
            adata_train=Triana_dataset_Train,
            train_df=Triana_data_Train,
            embedding_source=EMBEDDING_SOURCE,
            obsm_key=EMBEDDING_OBSM_KEY,
            obs_x=EMBEDDING_OBS_X,
            obs_y=EMBEDDING_OBS_Y,
            df_x=EMBEDDING_DF_X,
            df_y=EMBEDDING_DF_Y,
            neg_color="#A3A3A3",
            outline=(5, 0.05),
            debug=False,
        )
    except Exception as e:
        _report_failure(f"UMAP train plot failed for '{celltype}'", e)

    # 4.2 Load TEST barcodes for class-specific metrics (optional)
    test_barcodes_df = pd.read_csv(
        f"{test_barcodes_path}/Triana/Consensus_annotation_{name_target_class.lower()}_final/Barcodes_testing_class_{cls_safe}.csv",
        index_col=0
    )
    test_positive_barcodes = test_barcodes_df["Positive"].dropna().values
    test_negative_barcodes = test_barcodes_df["Negative"].dropna().values
    all_test_barcodes = np.concatenate([test_positive_barcodes, test_negative_barcodes])

    test_mask = Triana_data_Test_Sub.index.isin(all_test_barcodes)
    X_te_df = Triana_data_Test_Sub.loc[test_mask]
    found_test_barcodes = X_te_df.index.values
    y_te = np.isin(found_test_barcodes, test_positive_barcodes).astype(int)

    # Full TEST (filtered) for head probabilities / calibration plot eval
    X_te_all_local = X_te_all_df
    y_te_all = (Triana_data_Test.loc[X_te_all_df.index, "Celltype"].astype(str).values == str(celltype)).astype(int)

    # CAL split (filtered) for Platt fitting
    X_cal_df  = X_cal_all_df
    y_cal_bin = (Triana_data_Cal.loc[X_cal_all_df.index, "Celltype"].astype(str).values == str(celltype)).astype(int)

    # 4.3 Fit scaler on TRAIN; transform all splits
    scaler = StandardScaler(with_mean=True, with_std=True).fit(X_tr_df.values)

    def _sc(df: pd.DataFrame) -> pd.DataFrame:
        return pd.DataFrame(scaler.transform(df.values), index=df.index, columns=cols_train)

    X_tr_sc_df      = _sc(X_tr_df)
    X_te_sc_df      = _sc(X_te_df)
    X_te_all_sc_df  = _sc(X_te_all_local)
    X_cal_sc_df     = _sc(X_cal_df)

    # 4.4 Train base learners
    NB_model  = MLTraining.train_NB (X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)
    XGB_model = MLTraining.train_XGB(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)
    KNN_model = MLTraining.train_KNN(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)
    MLP_model = MLTraining.train_MLP(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)

    # 4.5 Stacking RAW head (LR meta-learner on out-of-fold base-learner probabilities)
    stacker_raw = StackingClassifier(
        estimators=[("NB", NB_model), ("XGB", XGB_model), ("KNN", KNN_model), ("MLP", MLP_model)],
        final_estimator=LogisticRegression(max_iter=2000, class_weight="balanced", random_state=42),
        stack_method="predict_proba",
        cv=kf,
        n_jobs=-1,
    ).fit(X_tr_sc_df, y_tr)

    # 4.6 Platt calibration (fit on CAL only; needs both classes present)
    pos_cal   = int(y_cal_bin.sum())
    n_cal_bin = int(len(y_cal_bin))
    has_both  = (0 < pos_cal < n_cal_bin)

    stacker_platt = None
    if has_both:
        stacker_platt = MLTraining.calibrate_prefit(stacker_raw, X_cal_sc_df, y_cal_bin, method="sigmoid")
    else:
        print("    [WARN] Skipped Platt calibration (single-class CAL)")

    # 4.7 Platt evaluation curve on TEST (Ideal -> RAW -> Platt) + metrics row
    try:
        p_test_raw   = stacker_raw.predict_proba(X_te_all_sc_df)[:, 1]
        p_test_platt = stacker_platt.predict_proba(X_te_all_sc_df)[:, 1] if stacker_platt is not None else None

        dev_platt = (fig_percls / f"{cls_safe}_Platt_calibration_evaluation_on_test.png") if EXPORT_DEV else None
        rel_platt = (release_single / f"{cls_safe}_Platt_calibration_evaluation_on_test.png") if EXPORT_RELEASE else None

        ll_raw, br_raw, ll_pl, br_pl, pl_avail = MLTraining.plot_platt_calibration_on_test(
            y_true_bin=y_te_all.astype(int),
            p_raw=p_test_raw,
            p_platt=p_test_platt,
            title=f"{name_target_class} – {celltype}: Platt calibration evaluation on TEST",
            out_png_dev=dev_platt,
            out_png_rel=rel_platt,
            n_bins=15,
        )

        platt_metrics_rows.append({
            "depth": name_target_class,
            "class_name": str(celltype),
            "n_test_samples": int(len(y_te_all)),
            "n_test_positive": int(y_te_all.sum()),
            "logloss_raw": ll_raw,
            "brier_raw": br_raw,
            "logloss_platt": ll_pl,
            "brier_platt": br_pl,
            "platt_available": bool(pl_avail),
        })

        # (E) make "saved but not where expected" obvious
        if DEBUG_DIAGNOSTICS:
            for _p in (dev_platt, rel_platt):
                if _p is not None:
                    print(f"    [DEBUG] Platt plot {_p} exists={_p.exists()}")

    except Exception as e:
        _report_failure(f"Platt calibration plot failed for class '{celltype}'", e)

    # 4.8 Save per-class head bundle + keep in-memory for package
    head_bundle = {
        "atlas": "Triana",
        "depth": name_target_class,
        "label": str(celltype),
        "panel_name": "TotalSeqD_Heme_Oncology_CAT399906",
        "columns": cols_train,
        "scaler": scaler,
        "model_raw": stacker_raw,
        "model_platt": stacker_platt,
    }
    heads_mem[str(celltype)] = head_bundle

    if EXPORT_DEV:
        joblib.dump(head_bundle, heads_dir / f"{cls_safe}.joblib")

    # 4.9 Optional per-head metrics logging (class-specific TEST subset).
    # Best-effort, but no longer silent: failures (e.g. no TEST barcodes found) are reported.
    try:
        model_for_eval = stacker_platt if stacker_platt is not None else stacker_raw
        m = MLTraining.evaluate_classifier(model_for_eval, X_te_sc_df, y_te, plot_cm=False)
        m.update(celltype=str(celltype), used_platt=bool(stacker_platt is not None))
        metrics_log.append(m)
    except Exception as e:
        _report_failure(f"Per-head TEST metrics failed for class '{celltype}'", e)

    # 4.10 OvR probability matrices (RAW + PLATT) for multiclass downstream
    P_cal_raw[:, k] = stacker_raw.predict_proba(X_cal_sc_df)[:, 1]
    P_te_raw[:,  k] = stacker_raw.predict_proba(X_te_all_sc_df)[:, 1]

    if stacker_platt is not None:
        P_cal_platt[:, k] = stacker_platt.predict_proba(X_cal_sc_df)[:, 1]
        P_te_platt[:,  k] = stacker_platt.predict_proba(X_te_all_sc_df)[:, 1]
    else:
        P_cal_platt[:, k] = P_cal_raw[:, k]
        P_te_platt[:,  k] = P_te_raw[:,  k]

    # 4.11 SHAP: mean_abs + corr on TEST; beeswarm TRAIN only
    if HAS_SHAP:
        # Figure kept current while the SHAP helpers run (in case they draw on the
        # current figure). It is always closed afterwards; previously it leaked
        # one figure per class.
        shap_fig = plt.figure(figsize=(6, 6))
        try:
            shap_sum_test = MLTraining.xgb_shap_mean_abs_and_corr(XGB_model, X_te_all_sc_df, class_index=1)
            shap_sum_test["depth"] = name_target_class
            shap_sum_test["class_name"] = str(celltype)
            shap_sum_test["dataset"] = "TEST"
            xgb_shap_rows.extend(shap_sum_test.to_dict(orient="records"))

            # Beeswarm on TRAIN only, one PNG per enabled export tree
            beeswarm_paths = []
            if EXPORT_DEV:
                beeswarm_paths.append(fig_percls / f"{cls_safe}_SHAP_beeswarm_TRAIN.png")
            if EXPORT_RELEASE:
                beeswarm_paths.append(release_single / f"{cls_safe}_SHAP_beeswarm_TRAIN.png")

            if beeswarm_paths:
                # (C) optional deterministic subsample to bound SHAP memory/time
                X_bees = X_tr_sc_df
                if SHAP_TRAIN_SUBSAMPLE_MAX_N is not None and len(X_bees) > SHAP_TRAIN_SUBSAMPLE_MAX_N:
                    X_bees = X_bees.sample(n=int(SHAP_TRAIN_SUBSAMPLE_MAX_N), random_state=42)

                # SHAP values computed once and reused for every export target
                sv, _ = MLTraining.xgb_shap_values(XGB_model, X_bees, class_index=1)
                for outp in beeswarm_paths:
                    MLTraining.plot_xgb_shap_beeswarm(
                        sv, X_bees,
                        title=f"{name_target_class} – {celltype}: SHAP importance",
                        max_display=TOP_N,
                        out_path=outp,
                        figsize=(6, 7),
                    )
                    if DEBUG_DIAGNOSTICS:
                        print(f"    [DEBUG] SHAP beeswarm {outp} exists={outp.exists()}")

        except Exception as e:
            _report_failure(f"SHAP failed for class '{celltype}'", e)
        finally:
            plt.close(shap_fig)

    # 4.12 LR meta-learner contributions (unchanged)
    try:
        contrib = _lr_baselearner_contributions(stacker_raw, X_te_all_sc_df, base_order=base_order)
        row = {
            "depth": name_target_class,
            "class_name": str(celltype),
            "dataset": "TEST",
            "n_meta_features": contrib["n_meta_features"],
            "per_estimator_meta_cols": contrib["per_estimator_meta_cols"],
        }
        for b in base_order:
            row[f"{b}_mean_abs_contribution"] = contrib["per_base"].get(b, {}).get("mean_abs_contribution", 0.0)
            row[f"{b}_coef_l1"]               = contrib["per_base"].get(b, {}).get("coef_l1", 0.0)
            row[f"{b}_n_meta_cols"]           = contrib["per_base"].get(b, {}).get("n_cols", 0)
        lr_contrib_rows.append(row)
    except Exception as e:
        _report_failure(f"LR contribution extraction failed for class '{celltype}'", e)

    print("")

# =============================================================================
# EXPORT: Per-class LogLoss & Brier (pre vs post Platt) on TEST
# =============================================================================
print("\n[EXPORT] Per-class calibration metrics (RAW vs Platt on TEST)...")

# Rows were collected in step 4.7. Presumably `filename` is written into each
# non-None directory — confirm in MLTraining. The return value is unused.
_ = MLTraining.export_platt_metrics_csv(
    platt_metrics_rows,
    out_dev=metrics_dir if EXPORT_DEV else None,
    out_rel=release_metrics if EXPORT_RELEASE else None,
    filename="Single_classes_metrics_pre_and_post_platt_calibration.csv",
)

# =============================================================================
# SECTION 5: MULTICLASS TEMPERATURE SCALING (fit on CAL using PLATT matrix)
# =============================================================================
print("\n[STEP 5] Multiclass Temperature Scaling on CAL (using Platt OvR probabilities)...")

def _check_probs(P: np.ndarray, name: str):
    if np.isnan(P).any() or np.isinf(P).any():
        raise ValueError(f"{name} contains NaN/Inf")
    if (P < 0).any() or (P > 1).any():
        raise ValueError(f"{name} contains values outside [0,1]")

_check_probs(P_cal_platt, "P_cal_platt")
_check_probs(P_te_platt,  "P_te_platt")

# Fit one temperature on CAL (Platt OvR matrix vs multiclass labels); apply it to TEST.
ts_cal = TemperatureScaling()
ts_cal.fit(P_cal_platt, y_cal_multiclass)
P_te_cal = np.asarray(ts_cal.transform(P_te_platt))
# Record what the scaler actually returned before reshaping/fallback overwrites
# P_te_cal. The fallback warning used to print the *replacement* shape instead.
ts_returned_shape = P_te_cal.shape

if P_te_cal.ndim == 1:
    P_te_cal = P_te_cal.reshape(-1, 1)

if P_te_cal.shape[1] == 1 and K == 2:
    # single column is treated as P(class 1): expand to [P(class 0), P(class 1)]
    P_te_cal = np.hstack([1.0 - P_te_cal, P_te_cal])
elif P_te_cal.shape[1] != K:
    # NOTE(review): in this fallback the exported package still carries ts_cal,
    # which does not match the probabilities used for the TEST figures/metrics.
    row_sums = P_te_platt.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0.0] = 1.0
    P_te_cal = P_te_platt / row_sums
    print(f"  [WARN] TemperatureScaling returned shape {ts_returned_shape}; fell back to sum-normalized OvR probs")

if EXPORT_DEV:
    joblib.dump(ts_cal, models_root / "temp_scaler.joblib")
    pd.Series(class_names, name="class_name").to_csv(models_root / "class_names.csv", index=False)

# =============================================================================
# SECTION 5b: SAVE DEPLOYABLE PACKAGE(S)
# =============================================================================
print("\n[STEP 5b] Saving deployable package(s)...")

# Inference bundle: ordered class list, per-class heads (scaler + RAW/Platt
# stackers, keyed by str(class label)) and the fitted temperature scaler.
# NOTE(review): classes skipped in Section 4 appear in class_names but have no entry in heads.
package = {
    "atlas": "Triana",
    "depth": name_target_class,
    "panel_name": "TotalSeqD_Heme_Oncology_CAT399906",
    "class_names": class_names,
    "heads": heads_mem,
    "temp_scaler": ts_cal,
}

if EXPORT_DEV:
    joblib.dump(package, models_root / "package.joblib")

if EXPORT_RELEASE:
    joblib.dump(package, release_models / "Multiclass_models.joblib")

# =============================================================================
# SECTION 5c: EXPORT IMPORTANCES (Top10 per class)
# =============================================================================
print("\n[STEP 5c] Exporting importances (Top 10 per class; SHAP mean_abs + corr + LR)...")

# SHAP importances: keep the TOP_N features with the largest mean |SHAP| per
# (depth, class_name) group.
# NOTE(review): banner/print say "Top 10"; the actual cut-off is TOP_N, which is
# defined earlier (not visible here) — keep them in sync.
if len(xgb_shap_rows) > 0:
    shap_df = pd.DataFrame(xgb_shap_rows)
    shap_df = (
        shap_df.sort_values(["depth", "class_name", "mean_abs_shap"], ascending=[True, True, False])
               .groupby(["depth", "class_name"], as_index=False)
               .head(TOP_N)
    )
    # 1-based rank within each (depth, class_name) by descending mean |SHAP|;
    # method="first" breaks ties by row order, so ranks are unique per group.
    shap_df["rank_within_class"] = (
        shap_df.groupby(["depth", "class_name"])["mean_abs_shap"]
               .rank(ascending=False, method="first")
               .astype(int)
    )
    # Fix the exported column order (this also drops any other columns).
    shap_df = shap_df[
        ["depth", "class_name", "dataset", "feature", "mean_abs_shap", "corr_feature_value_vs_shap", "rank_within_class"]
    ]
    if EXPORT_DEV:
        shap_df.to_csv(dev_importances / "SHAP_XGB_Feature_importances.csv", index=False)
    if EXPORT_RELEASE:
        shap_df.to_csv(release_imps / "SHAP_XGB_Feature_importances.csv", index=False)
else:
    print("  [INFO] No SHAP rows collected (or SHAP not installed).")

# Meta-learner (LR) contributions of each base learner are exported unfiltered.
if len(lr_contrib_rows) > 0:
    lr_df = pd.DataFrame(lr_contrib_rows)
    if EXPORT_DEV:
        lr_df.to_csv(dev_importances / "LR_MetaLearner_BaseLearner_contributions.csv", index=False)
    if EXPORT_RELEASE:
        lr_df.to_csv(release_imps / "LR_MetaLearner_BaseLearner_contributions.csv", index=False)
else:
    print("  [INFO] No LR contribution rows collected.")

# =============================================================================
# SECTION 6: SAVE PROBABILITIES
# =============================================================================
print("\n[STEP 6] Saving probability outputs...")

# Dev export: RAW, Platt and temperature-calibrated (CAL) probability columns
# side by side for every TEST cell, plus argmax predictions (index and name)
# for RAW and CAL and the true label.
if EXPORT_DEV:
    probs_raw_df   = pd.DataFrame(P_te_raw,   index=test_index, columns=[f"raw_{c}"   for c in class_names])
    probs_platt_df = pd.DataFrame(P_te_platt, index=test_index, columns=[f"platt_{c}" for c in class_names])
    probs_cal_df   = pd.DataFrame(P_te_cal,   index=test_index, columns=[f"cal_{c}"   for c in class_names])

    probs_dev = pd.concat([probs_raw_df, probs_platt_df, probs_cal_df], axis=1)
    probs_dev["true_label"] = Triana_data_Test.loc[test_index, "Celltype"].values
    probs_dev["pred_raw"]   = P_te_raw.argmax(axis=1)
    probs_dev["pred_cal"]   = P_te_cal.argmax(axis=1)
    probs_dev["pred_raw_name"] = [class_names[i] for i in probs_dev["pred_raw"].values]
    probs_dev["pred_cal_name"] = [class_names[i] for i in probs_dev["pred_cal"].values]
    probs_dev.to_csv(probs_dir / "probabilities_before_after_TEST.csv", index=True)

# Release export: only the calibrated probabilities, the CAL prediction and
# its confidence (row-wise max calibrated probability).
if EXPORT_RELEASE:
    probs_cal_df = pd.DataFrame(P_te_cal, index=test_index, columns=[f"cal_{c}" for c in class_names])
    probs_release = probs_cal_df.copy()
    probs_release["true_label"]    = Triana_data_Test.loc[test_index, "Celltype"].values
    probs_release["pred_cal"]      = P_te_cal.argmax(axis=1)
    probs_release["pred_cal_name"] = [class_names[i] for i in probs_release["pred_cal"].values]
    probs_release["max_cal_prob"]  = probs_cal_df.max(axis=1).values
    probs_release.to_csv(release_probs / "Multiclass_models_probabilities_on_test.csv", index=True)

# =============================================================================
# SECTION 7: MULTICLASS EVALUATION (TEST) — using CAL probabilities
# =============================================================================
print("\n[STEP 7] Multiclass evaluation (TEST; using CAL probs)...\n")

y_pred_cal = P_te_cal.argmax(axis=1)

# Pin the label set to all K classes so `target_names=class_names` always
# lines up with it. Without `labels=`, sklearn infers the labels from
# y_true ∪ y_pred and raises "Number of classes ... does not match size of
# target_names" whenever a class is absent from TEST and never predicted.
# Such classes now appear in the report with support 0 instead.
all_labels = list(range(K))

report_txt = classification_report(
    y_test_multiclass, y_pred_cal, labels=all_labels, target_names=class_names, digits=3
)
print("Multiclass Classification Report (TEST):")
print(report_txt)

cm_mc = confusion_matrix(y_test_multiclass, y_pred_cal, labels=all_labels)
print("\nConfusion Matrix (rows=true, cols=pred):")
print(cm_mc)

report_df = pd.DataFrame(
    classification_report(
        y_test_multiclass, y_pred_cal, labels=all_labels, target_names=class_names, output_dict=True
    )
).T

cm_mc_df = pd.DataFrame(cm_mc, index=pd.Index(class_names, name="true"), columns=pd.Index(class_names, name="pred"))

if EXPORT_DEV:
    report_df.to_csv(metrics_dir / "multiclass_classification_report_TEST.csv")
    cm_mc_df.to_csv(metrics_dir / "multiclass_confusion_matrix_TEST.csv")

if EXPORT_RELEASE:
    report_df.to_csv(release_metrics / "Multiclass_models_metrics_on_test.csv")
    cm_mc_df.to_csv(release_metrics / "Multiclass_models_confusion_matrix_on_test.csv")

# =============================================================================
# SECTION 8: FIGURES (MULTICLASS CM + PER-CLASS CONF & ROC)
# =============================================================================
# All figures below are saved at dpi=300 to the dev and/or release folders
# selected by EXPORT_DEV / EXPORT_RELEASE.
print("\n[STEP 8] Saving plots...")

def _save_multiclass_cm_png(out_path: Path):
    """
    Save the K x K TEST confusion matrix as a PNG.

    Reads the module-level ``cm_mc``, ``class_names`` and ``name_target_class``.

    Parameters
    ----------
    out_path : Path
        Destination file (written at dpi=300).
    """
    # Create the figure/axes explicitly and draw into them: without ``ax=``,
    # ConfusionMatrixDisplay.plot opens its own default-size figure. That made
    # the requested figsize a no-op and left the plotted figure open, because
    # only the empty one was closed.
    fig, ax = plt.subplots(figsize=(7, 6))
    try:
        disp = ConfusionMatrixDisplay(confusion_matrix=cm_mc, display_labels=class_names)
        disp.plot(ax=ax, values_format="d", cmap="Blues", colorbar=False)
        ax.set_title(f"{name_target_class} – Multiclass Confusion Matrix (on TEST)")
        fig.tight_layout()
        fig.savefig(out_path, dpi=300)
    finally:
        # Always release the figure, even if plotting/saving fails.
        plt.close(fig)

# Write the multiclass confusion matrix to whichever destinations are enabled.
if EXPORT_DEV:
    _save_multiclass_cm_png(fig_root / "multiclass_confusion_matrix_TEST.png")
if EXPORT_RELEASE:
    _save_multiclass_cm_png(release_figs / "Multiclass_models_confusion_matrix_on_test.png")

# Collects one metrics dict per (class, RAW|CAL) in the per-class loop below.
per_class_rows = []
# Argmax predictions from the uncalibrated OvR matrix, compared against CAL below.
y_pred_raw = P_te_raw.argmax(axis=1)

def _metrics_from_cm(cm2x2):
    tn, fp, fn, tp = cm2x2.ravel()
    support = tp + fn
    prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    rec  = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1   = (2 * prec * rec) / (prec + rec) if (prec + rec) > 0 else 0.0
    return dict(TP=int(tp), FP=int(fp), TN=int(tn), FN=int(fn),
                support=int(support), precision=prec, recall=rec, f1=f1)

def _save_cm_fig(cm2x2, cls_label, title, out_dev: Path | None, out_rel: Path | None):
    """
    Render a one-vs-rest 2x2 confusion matrix and save it to up to two paths.

    Parameters
    ----------
    cm2x2 : array-like
        2x2 count matrix (rows = true, cols = predicted; order Other, class).
    cls_label : str
        Display name of the positive class; the negative class is shown as "Other".
    title : str
        Axes title.
    out_dev, out_rel : Path or None
        Destinations (dpi=300); None skips that destination.
    """
    # Draw into an explicitly sized axes: without ``ax=``,
    # ConfusionMatrixDisplay.plot creates a separate default-size figure. The
    # requested figsize was then ignored and that figure was never closed,
    # leaking one figure per call (two per class in the loop).
    fig, ax = plt.subplots(figsize=(5.5, 5.0))
    try:
        ConfusionMatrixDisplay(confusion_matrix=cm2x2, display_labels=["Other", cls_label]).plot(
            ax=ax, values_format="d", cmap="Blues", colorbar=False
        )
        ax.set_title(title)
        fig.tight_layout()
        for target in (out_dev, out_rel):
            if target is not None:
                fig.savefig(target, dpi=300)
    finally:
        plt.close(fig)

def _save_roc(y_true, y_score, title, out_dev: Path | None, out_rel: Path | None):
    """
    Plot a binary ROC curve, save it to up to two paths, and return its AUC.

    Parameters
    ----------
    y_true : array-like of {0, 1}
        Binary ground truth.
    y_score : array-like of float
        Scores/probabilities for the positive class.
    title : str
        Title prefix; " AUC=<value>" (3 decimals) is appended.
    out_dev, out_rel : Path or None
        Destinations (dpi=300); None skips that destination.

    Returns
    -------
    float
        Area under the ROC curve.
    """
    fpr, tpr, _thresholds = roc_curve(y_true, y_score)
    roc_auc = auc(fpr, tpr)

    fig, ax = plt.subplots(figsize=(6.0, 5.5))
    # Chance diagonal first, then the model's curve.
    ax.plot([0, 1], [0, 1], linestyle="--", linewidth=1, color="gray")
    ax.plot(fpr, tpr)
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title(f"{title} AUC={roc_auc:.3f}")
    fig.tight_layout()

    for target in (out_dev, out_rel):
        if target is not None:
            fig.savefig(target, dpi=300)
    plt.close(fig)
    return roc_auc

# For each class k: binarize TEST labels (k vs. rest) and the RAW / CAL argmax
# predictions, save confusion-matrix and ROC figures for both variants, and
# record their one-vs-rest metrics in `per_class_rows`.
# Note: the confusion matrices use the multiclass argmax prediction, whereas the
# ROC curves use the class-k probability column directly.
for k, cls in enumerate(class_names):
    cls_safe = MLTraining.safe_name(cls)
    y_true_bin = (y_test_multiclass == k).astype(int)

    score_raw = P_te_raw[:, k]
    score_cal = P_te_cal[:, k]

    # "Predicted k" means k won the argmax over all classes.
    y_pred_raw_bin = (y_pred_raw == k).astype(int)
    y_pred_cal_bin = (y_pred_cal == k).astype(int)

    # labels=[0, 1] keeps a fixed 2x2 layout even if one side is absent.
    cm_raw = confusion_matrix(y_true_bin, y_pred_raw_bin, labels=[0, 1])
    cm_cal = confusion_matrix(y_true_bin, y_pred_cal_bin, labels=[0, 1])

    dev_out = (fig_percls / f"{cls_safe}_RAW_confusion_matrix_on_test.png") if EXPORT_DEV else None
    rel_out = (release_single / f"{cls_safe}_RAW_confusion_matrix_on_test.png") if EXPORT_RELEASE else None
    _save_cm_fig(cm_raw, cls, f"{name_target_class} – {cls}: Confusion Matrix (RAW; pre-Platt & Temp)", dev_out, rel_out)

    dev_out = (fig_percls / f"{cls_safe}_CAL_confusion_matrix_on_test.png") if EXPORT_DEV else None
    rel_out = (release_single / f"{cls_safe}_CAL_confusion_matrix_on_test.png") if EXPORT_RELEASE else None
    _save_cm_fig(cm_cal, cls, f"{name_target_class} – {cls}: Confusion Matrix (CAL; Platt & Temp)", dev_out, rel_out)

    dev_out = (fig_percls / f"{cls_safe}_RAW_ROC_on_test.png") if EXPORT_DEV else None
    rel_out = (release_single / f"{cls_safe}_RAW_ROC_on_test.png") if EXPORT_RELEASE else None
    auc_raw = _save_roc(y_true_bin, score_raw, f"{name_target_class} – {cls}: ROC (RAW; pre-Platt & Temp)", dev_out, rel_out)

    dev_out = (fig_percls / f"{cls_safe}_CAL_ROC_on_test.png") if EXPORT_DEV else None
    rel_out = (release_single / f"{cls_safe}_CAL_ROC_on_test.png") if EXPORT_RELEASE else None
    auc_cal = _save_roc(y_true_bin, score_cal, f"{name_target_class} – {cls}: ROC (CAL; Platt & Temp)", dev_out, rel_out)

    # One row per variant: counts/precision/recall/f1 plus the ROC AUC.
    m_raw = _metrics_from_cm(cm_raw); m_raw.update(model="RAW", class_name=cls, auc=auc_raw); per_class_rows.append(m_raw)
    m_cal = _metrics_from_cm(cm_cal); m_cal.update(model="CAL", class_name=cls, auc=auc_cal); per_class_rows.append(m_cal)

# =============================================================================
# SECTION 9: SAVE METRICS TABLES
# =============================================================================
print("\n[STEP 9] Saving metrics tables...")

# One row per (class, RAW|CAL) from the per-class loop, with a fixed column
# order, sorted by class then model variant.
per_class_df = pd.DataFrame(per_class_rows)[
    ["class_name", "model", "TP", "FP", "TN", "FN", "support", "precision", "recall", "f1", "auc"]
].sort_values(["class_name", "model"])

if EXPORT_DEV:
    per_class_df.to_csv(metrics_dir / "per_class_argmax_metrics_TEST_included.csv", index=False)

if EXPORT_RELEASE:
    out_single = release_metrics / "Single_classes_metrics_and_confusion_matrix_on_test.csv"
    per_class_df.to_csv(out_single, index=False)

# Persist the binary-head metrics collected in `metrics_log` (built earlier in
# this notebook; not visible here) to the dev folder.
# NOTE(review): append semantics live in MLTraining.append_metrics_csv —
# presumably rows accumulate across runs in stacker_metrics.csv; confirm.
if EXPORT_DEV:
    metrics_df = pd.DataFrame.from_records(metrics_log)
    MLTraining.append_metrics_csv(metrics_df, csv_path=dev_root / "stacker_metrics.csv")

print("\n✅ SIMPLIFIED PIPELINE COMPLETE. Exports saved according to EXPORT_DEV / EXPORT_RELEASE.\n")


In [None]:
from pathlib import Path
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
from matplotlib import rc_context

# --------- 1) Collect training barcodes (union across classes or a single class) ---------
def collect_training_barcodes(train_barcodes_path: str | Path,
                              atlas: str = "Triana",
                              depth: str = "simplified",
                              class_name: str | None = None) -> pd.Index:
    """
    Reads CSVs like:
      {train_barcodes_path}/{atlas}/Consensus_annotation_{depth}_final/Barcodes_training_class_{Class}.csv
    and returns the union of 'Positive' and 'Negative' barcodes.
    """
    base = Path(train_barcodes_path) / atlas / f"Consensus_annotation_{depth.lower()}_final"

    if class_name is None:
        files = sorted(base.glob("Barcodes_training_class_*.csv"))
        if not files:
            raise FileNotFoundError(f"No training CSVs found in: {base}")
    else:
        files = [base / f"Barcodes_training_class_{str(class_name).replace(' ', '_')}.csv"]
        if not files[0].exists():
            raise FileNotFoundError(f"Training CSV not found: {files[0]}")

    barcodes = []
    for fp in files:
        df = pd.read_csv(fp, index_col=0)
        for col in ("Positive", "Negative"):
            if col in df.columns:
                barcodes.extend(df[col].dropna().astype(str).tolist())

    return pd.Index(pd.unique(barcodes))


# --------- 2) Plot UMAP for the subset of training barcodes ---------
def plot_training_subset_umap(
    adata,                                 # Triana_dataset_Train (AnnData)
    train_barcodes_path: str | Path,
    *,
    atlas: str = "Triana",
    depth: str = "simplified",             # matches your CSV folder name
    class_name: str | None = None,         # None -> all classes; or "B Memory", etc.
    label_key: str = "Consensus_annotation_detailed_final",
    basis_key: str = "X_mofaumap",
    custom_palette: dict | None = None,    # your dict {label: "#hex"}
    point_size: float = 10.0,
    title: str | None = None,
):
    """
    Plot an embedding of only the cells listed in the training barcode CSVs.

    Cells are kept when their ``adata.obs_names`` appear in
    ``collect_training_barcodes(...)`` for the given atlas/depth/class. They are
    coloured by ``adata.obs[label_key]`` and drawn with ``sc.pl.embedding``
    (labels on data, ``show=True``).

    Parameters
    ----------
    adata : AnnData
        Dataset whose obs_names are barcodes (e.g. Triana_dataset_Train).
    train_barcodes_path, atlas, depth, class_name
        Forwarded to ``collect_training_barcodes``; ``class_name=None`` uses the
        union over all classes.
    label_key : str
        obs column used for colouring; made categorical on the plotted subset.
    basis_key : str
        Embedding passed as ``basis`` to ``sc.pl.embedding``.
    custom_palette : dict or None
        {label: colour}. Labels missing from it are drawn light grey
        ("#cccccc"); None lets scanpy pick colours.
    point_size : float
        Marker size.
    title : str or None
        Plot title; defaults to one naming the atlas and the selected class(es).

    Raises
    ------
    FileNotFoundError
        Propagated from ``collect_training_barcodes``.
    ValueError
        If no training barcode is present in ``adata.obs_names``.
    KeyError
        If ``label_key`` is not a column of ``adata.obs``.
    """
    # 1) get training barcodes
    keep_barcodes = collect_training_barcodes(train_barcodes_path, atlas=atlas, depth=depth, class_name=class_name)

    # 2) subset AnnData
    n_before = adata.n_obs
    mask = adata.obs_names.isin(keep_barcodes)
    n_after = int(mask.sum())
    if n_after == 0:
        raise ValueError("No training barcodes overlapped with adata.obs_names.")
    # Fail with a clear message before paying for the AnnData copy.
    if label_key not in adata.obs.columns:
        raise KeyError(f"'{label_key}' not found in adata.obs")
    ad_sub = adata[mask].copy()
    print(f"[subset] kept {n_after}/{n_before} cells for plotting "
          f"({'ALL classes' if class_name is None else f'class={class_name}'})")

    # 3) ensure categorical for coloring
    # (isinstance on CategoricalDtype replaces pd.api.types.is_categorical_dtype,
    # which is deprecated since pandas 2.1.)
    labels = ad_sub.obs[label_key]
    if not isinstance(labels.dtype, pd.CategoricalDtype):
        ad_sub.obs[label_key] = labels.astype("category")
    cats = list(ad_sub.obs[label_key].cat.categories)

    # 4) palette handling (use your custom palette where available; fallback to grey)
    fallback = "#cccccc"
    if custom_palette is None:
        palette_list = None  # let scanpy pick
    else:
        labels_in_palette = set(custom_palette.keys())
        labels_in_data = set(cats)
        missing_in_palette = [c for c in cats if c not in labels_in_palette]
        extra_in_palette   = [c for c in custom_palette.keys() if c not in labels_in_data]
        if missing_in_palette:
            print("[WARN] Missing colors for:", missing_in_palette, "-> using light grey (#cccccc).")
        if extra_in_palette:
            print("[INFO] Palette has unused entries:", extra_in_palette)

        # One colour per category, in category order.
        palette_list = [custom_palette.get(c, fallback) for c in cats]
        # stash into .uns so scanpy reuses it
        ad_sub.uns[f"{label_key}_colors"] = palette_list

    # 5) plot
    with rc_context({"figure.figsize": (5.75, 4.75)}):
        sc.pl.embedding(
            ad_sub,
            basis=basis_key,
            color=label_key,
            palette=palette_list,
            legend_loc="on data",
            legend_fontsize=10,
            legend_fontoutline=1.5,
            size=point_size,
            add_outline=True,
            frameon=False,
            title=title or f"{atlas} — training subset ({'all classes' if class_name is None else class_name})",
            show=True,
        )


# ---------------------- EXAMPLES ----------------------
# 1) Plot ALL training barcodes (union of every class) with your detailed labels + palette
# plot_training_subset_umap(
#     Triana_dataset_Train,
#     train_barcodes_path=train_barcodes_path,
#     atlas="Triana",
#     depth="simplified",
#     class_name=None,
#     label_key="Consensus_annotation_detailed_final",
#     basis_key="X_mofaumap",
#     custom_palette=custom_palette,  # pass the dict you defined
#     point_size=10,
#     title="Triana — all training cells",
# )

# 2) Plot only one class’s training barcodes (e.g., "Erythroid").
# The title is derived from the same variable as `class_name`, so the two
# cannot drift apart (it previously said "B Memory" while plotting Erythroid).
example_class = "Erythroid"
plot_training_subset_umap(
    Triana_dataset_Train,
    train_barcodes_path=train_barcodes_path,
    atlas="Triana",
    depth="simplified",
    class_name=example_class,
    label_key="Consensus_annotation_detailed_final",
    basis_key="X_mofaumap",
    custom_palette=custom_palette,
    point_size=10,
    title=f"Triana — training cells: {example_class}",
)


In [None]:
# -*- coding: utf-8 -*-
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.ensemble import StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.calibration import CalibratedClassifierCV
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix

from netcal.scaling import TemperatureScaling
import joblib

# ------------------------------------------------------------------- CONFIG (expects these to already exist)
#   models_output, train_barcodes_path, test_barcodes_path
#   Triana_data_Train, Triana_data_Test, Triana_data_Cal          (DataFrames indexed by barcode)
#   Triana_dataset_Train, Triana_dataset_Test, Triana_dataset_Cal (AnnData with obs labels)
#   TotalSeqD_Heme_Oncology_CAT399906                    (iterable of feature names)
#   MLTraining module with: CV, train_NB, train_XGB, train_KNN, train_MLP,
#                           plot_calibration_curve, save_models, evaluate_classifier, append_metrics_csv

# Label depth; lower-cased below to build the obs field name
# ("Consensus_annotation_<depth>_final").
name_target_class = "Simplified"   # "simplified" | "Simplified" | "Detailed"
fig_root   = Path(models_output) / "Figures"
models_dir = Path(models_output) / "Models"
fig_root.mkdir(parents=True, exist_ok=True)
models_dir.mkdir(parents=True, exist_ok=True)

# NOTE: these re-bind names set in the set-up cell (KFold(n_splits=5, ...) and
# cpu_count() - 2). `MLTraining.CV` is defined in the MLTraining module (not
# visible here). num_cores=-1 is presumably forwarded as n_jobs
# (-1 = all cores) — confirm in MLTraining.train_*.
kf         = MLTraining.CV
num_cores  = -1
# Receives one binary-metrics dict per trained head in the training loop.
metrics_log = []

# ============================= HELPERS =============================

def _norm_feats(names) -> pd.Index:
    """
    Normalizer used ONLY to construct matching keys.
    Panel names remain untouched; data columns are normalized and then mapped BACK
    to the exact panel names via a lookup.
    """
    s = pd.Index(map(str, names))
    s = (s.str.strip()
           .str.lower()
           .str.replace(r"[ _/]+", "-", regex=True)
           .str.replace(r"-+", "-", regex=True)
           .str.strip("-"))
    return s

def attach_celltype(df: pd.DataFrame, ad: "AnnData", field: str) -> pd.DataFrame:
    """
    Return a copy of ``df`` with a categorical "Celltype" column from ``ad.obs[field]``.

    Labels are matched to ``df`` by index (barcode). Surrounding whitespace is
    stripped and inner whitespace runs become "_" (e.g. "B Memory" ->
    "B_Memory"). Rows of ``df`` without a matching obs entry get NaN, and a
    warning with their count is printed. ``df`` itself is not modified.

    Raises
    ------
    KeyError
        If ``field`` is not a column of ``ad.obs``.
    """
    if field not in ad.obs:
        raise KeyError(f"'{field}' not found in AnnData.obs")

    cleaned = ad.obs[field].astype("string").str.strip()
    cleaned = cleaned.str.replace(r"\s+", "_", regex=True)

    result = df.copy()
    result["Celltype"] = pd.Categorical(cleaned.reindex(result.index))

    n_missing = int(result["Celltype"].isna().sum())
    if n_missing:
        print(f"[WARN] {n_missing} rows got NaN Celltype after reindex; check barcode alignment.")
    return result

def _check_finite(df: pd.DataFrame, tag: str):
    arr = df.to_numpy()
    if not np.isfinite(arr).all():
        bad = np.where(~np.isfinite(arr))
        raise ValueError(f"Non-finite values found in {tag} features at positions {bad}")

def _unwrap_estimator(m):
    return getattr(m, "estimator", None) or getattr(m, "base_estimator", None) or m

def _assert_feature_counts(cell_name: str, models_dict: dict, expected: int):
    """
    Check that each model of a head was fitted on exactly ``expected`` features.

    Inspects the "NB", "XGB", "KNN", "MLP" and "Stacker" entries of
    ``models_dict`` in that order; missing or None entries are ignored. Each
    model is passed through ``_unwrap_estimator`` first. The comparison only
    happens when the resulting object exposes ``n_features_in_``.

    Raises
    ------
    RuntimeError
        On the first model whose ``n_features_in_`` differs from ``expected``.
    """
    for label in ("NB", "XGB", "KNN", "MLP", "Stacker"):
        model = models_dict.get(label)
        if model is None:
            continue
        seen = getattr(_unwrap_estimator(model), "n_features_in_", None)
        if seen is None or seen == expected:
            continue
        raise RuntimeError(f"{cell_name}:{label} saw {seen} features; expected {expected}")

# ============================= LABEL ATTACH =============================

# obs column holding labels at the chosen depth, e.g.
# "Consensus_annotation_simplified_final" for name_target_class="Simplified".
consensus_field = f"Consensus_annotation_{name_target_class.lower()}_final"

# Add a categorical "Celltype" column to each feature table, aligned by barcode
# (DataFrame index) to the matching AnnData split.
Triana_data_Train = attach_celltype(Triana_data_Train, Triana_dataset_Train, consensus_field)
Triana_data_Test  = attach_celltype(Triana_data_Test,  Triana_dataset_Test,  consensus_field)
Triana_data_Cal   = attach_celltype(Triana_data_Cal,   Triana_dataset_Cal,   consensus_field)

# ============================= PANEL & DATA COLUMN ALIGNMENT =============================

# Keep the panel EXACTLY as provided
panel = pd.Index(map(str, TotalSeqD_Heme_Oncology_CAT399906))

# Build a mapping: normalized_key -> exact panel name
panel_keys    = _norm_feats(panel)
norm_to_panel = dict(zip(panel_keys, panel))
# Fewer dict entries than panel names means two panel names share a normalized
# key (or the panel contains an exact duplicate), which would make matching ambiguous.
if len(norm_to_panel) != len(panel):
    raise ValueError("Panel contains names that collide after normalization. Consider adjusting _norm_feats rules.")

def _rename_data_to_panel(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of ``df`` whose feature columns carry the exact panel names.

    Every column other than 'cell_barcode' / 'Celltype' is normalized with
    ``_norm_feats`` and looked up in the module-level ``norm_to_panel`` map.
    Matches are renamed to the panel spelling; unmatched columns keep their
    name. When several data columns resolve to the same panel name, the first
    one (in column order) wins and the rest are dropped with a warning. Prints
    how many feature columns were matched.
    """
    out = df.copy()
    features = pd.Index([c for c in out.columns if c not in ("cell_barcode", "Celltype")])

    # Candidate renames: only columns whose normalized key exists in the panel.
    targets = [norm_to_panel.get(key) for key in _norm_feats(features)]
    candidates = {old: new for old, new in zip(features, targets) if new is not None}

    # First claim on each panel name wins; later claimants are dropped.
    safe_map = {}
    drops = []
    taken = set()
    for old, new in candidates.items():
        if new in taken:
            drops.append(old)
            continue
        taken.add(new)
        safe_map[old] = new

    if drops:
        print(f"[WARN] Dropping {len(drops)} duplicated-mapped columns (showing up to 5): {drops[:5]}")
        out = out.drop(columns=drops, errors="ignore")
    out = out.rename(columns=safe_map)

    print(f"[map] matched {len(safe_map)}/{len(features)} data columns to panel")
    return out

# Apply: normalize/rename ONLY data splits (panel remains untouched)
# Each call prints how many feature columns were matched to the panel.
Triana_data_Train = _rename_data_to_panel(Triana_data_Train)
Triana_data_Test  = _rename_data_to_panel(Triana_data_Test)
Triana_data_Cal   = _rename_data_to_panel(Triana_data_Cal)

# Intersect each split with the panel IN PANEL ORDER
def _panel_intersection(df: pd.DataFrame) -> pd.DataFrame:
    """
    Restrict ``df`` to panel features, in panel order, followed by metadata columns.

    Feature columns are every column except 'cell_barcode' / 'Celltype'. Those
    also present in the module-level ``panel`` are kept in panel order, and
    the metadata columns that exist are appended at the end.

    Raises
    ------
    ValueError
        If no feature column matches the panel.
    """
    meta_cols = [c for c in ("cell_barcode", "Celltype") if c in df.columns]
    feature_cols = pd.Index([c for c in df.columns if c not in meta_cols])
    # sort=False keeps the calling Index's (the panel's) order.
    shared = panel.intersection(feature_cols, sort=False)
    if shared.empty:
        raise ValueError("Panel/Data intersection is empty after renaming. Check mapping rules.")
    return df.reindex(columns=[*shared, *meta_cols])

Triana_data_Train = _panel_intersection(Triana_data_Train)
Triana_data_Test  = _panel_intersection(Triana_data_Test)
Triana_data_Cal   = _panel_intersection(Triana_data_Cal)

# ============================= FEATURES & LABELS =============================

# Keep the CAL labels separately; the feature tables below drop "Celltype".
Triana_data_Cal_lbl = Triana_data_Cal[["Celltype"]].copy()

# Feature matrices = every column except the barcode/label metadata.
drop_cols_train = [c for c in ["cell_barcode", "Celltype"] if c in Triana_data_Train.columns]
drop_cols_test  = [c for c in ["cell_barcode", "Celltype"] if c in Triana_data_Test.columns]
drop_cols_cal   = [c for c in ["cell_barcode", "Celltype"] if c in Triana_data_Cal.columns]

Triana_data_Train_Sub = Triana_data_Train.drop(columns=drop_cols_train, errors="ignore")
Triana_data_Test_Sub  = Triana_data_Test.drop(columns=drop_cols_test,  errors="ignore")
Triana_data_Cal_Sub   = Triana_data_Cal.drop(columns=drop_cols_cal,    errors="ignore")

# SAFETY: shared columns & finiteness checks
# (identical names AND order across splits, since scalers/models are fit on
# positional arrays).
cols_train = list(Triana_data_Train_Sub.columns)
if list(Triana_data_Test_Sub.columns) != cols_train or list(Triana_data_Cal_Sub.columns) != cols_train:
    raise ValueError("Train/Cal/Test feature columns differ after panel intersection!")

_check_finite(Triana_data_Train_Sub, "TRAIN")
_check_finite(Triana_data_Test_Sub,  "TEST")
_check_finite(Triana_data_Cal_Sub,   "CAL")

print(f"\n[features] Using {len(cols_train)} panel-intersected features (exact panel names):")
print(cols_train)

# Consistent class order: sorted distinct TRAIN labels. Index k is the column
# of the OvR probability matrices for class_names[k].
class_names  = sorted(pd.Series(Triana_data_Train["Celltype"]).dropna().unique())
K            = len(class_names)
class_to_idx = {c: i for i, c in enumerate(class_names)}

# Multiclass labels arrays: CAL / TEST labels as int64 class indices. A label
# absent from TRAIN (or a NaN label) maps to NaN and is a hard error.
s_cal = Triana_data_Cal_lbl["Celltype"].map(class_to_idx)
if s_cal.isna().any():
    missing = Triana_data_Cal_lbl.loc[s_cal.isna(), "Celltype"].unique()
    raise ValueError(f"Unknown labels in CAL split: {missing}")
y_cal_multiclass = s_cal.to_numpy(dtype=np.int64)

s_te = Triana_data_Test["Celltype"].map(class_to_idx)
if s_te.isna().any():
    missing = Triana_data_Test.loc[s_te.isna(), "Celltype"].unique()
    raise ValueError(f"Unknown labels in TEST split: {missing}")
y_test_multiclass = s_te.to_numpy(dtype=np.int64)

# Reuse across classes
X_cal_all_df = Triana_data_Cal_Sub.copy()
X_te_all_df  = Triana_data_Test_Sub.copy()

# Preallocate OvR prob mats (rows = cells, cols = classes). A class whose head
# is skipped in the training loop keeps an all-zero column.
P_cal = np.zeros((X_cal_all_df.shape[0], K), dtype=float)
P_te  = np.zeros((X_te_all_df.shape[0],  K), dtype=float)

# Row order of P_te / y_test_multiclass (Test_Sub keeps Triana_data_Test's rows).
test_index = Triana_data_Test_Sub.index

# ============================= TRAIN PER-CLASS OVR =============================
# One binary one-vs-rest head per class:
#   per-head StandardScaler -> NB / XGB / KNN / MLP base learners -> stacked
#   LogisticRegression meta-learner -> Platt (sigmoid) calibration on CAL when
#   CAL contains both classes.
# Each head's positive-class probabilities fill column k of P_cal / P_te.
# Classes skipped by the guard below keep an all-zero column.

for celltype in class_names:
    k = class_to_idx[celltype]
    name = str(celltype).replace(" ", "_")
    print(f"\nProcessing {name} (class {k+1}/{K})...")

    # ---- TRAIN slice via barcode lists
    # The barcode folder is derived from `consensus_field` (i.e. from
    # name_target_class) instead of being hard-coded to "simplified", so the
    # barcode lists always match the label depth being trained.
    train_barcodes_df = pd.read_csv(
        f"{train_barcodes_path}/Triana/{consensus_field}/Barcodes_training_class_{name}.csv",
        index_col=0
    )
    train_positive_barcodes = train_barcodes_df["Positive"].dropna().values
    train_negative_barcodes = train_barcodes_df["Negative"].dropna().values
    all_train_barcodes = np.concatenate([train_positive_barcodes, train_negative_barcodes])

    # TRAIN rows listed in either list; label 1 for the positive list, else 0.
    train_mask = Triana_data_Train_Sub.index.isin(all_train_barcodes)
    X_tr_df = Triana_data_Train_Sub.loc[train_mask]
    found_train_barcodes = X_tr_df.index.values
    y_tr = np.isin(found_train_barcodes, train_positive_barcodes).astype(int)

    # ---- Skip guards
    if X_tr_df.empty or np.unique(y_tr).size < 2:
        print(f"[SKIP] {name}: empty or single-class train slice (pos={y_tr.sum()}, neg={(len(y_tr)-y_tr.sum())}).")
        continue

    # ---- TEST slice via barcode lists (same folder convention as TRAIN)
    test_barcodes_df = pd.read_csv(
        f"{test_barcodes_path}/Triana/{consensus_field}/Barcodes_testing_class_{name}.csv",
        index_col=0
    )
    test_positive_barcodes = test_barcodes_df["Positive"].dropna().values
    test_negative_barcodes = test_barcodes_df["Negative"].dropna().values
    all_test_barcodes = np.concatenate([test_positive_barcodes, test_negative_barcodes])

    test_mask = Triana_data_Test_Sub.index.isin(all_test_barcodes)
    X_te_df = Triana_data_Test_Sub.loc[test_mask]
    found_test_barcodes = X_te_df.index.values
    y_te = np.isin(found_test_barcodes, test_positive_barcodes).astype(int)

    # ---- Full-test & cal for this binary head
    # (labels here come from the consensus "Celltype" column, not the barcode lists)
    X_te_all_local = X_te_all_df.copy()
    y_te_all = (Triana_data_Test["Celltype"].values == celltype).astype(int)
    X_cal_df = X_cal_all_df.copy()
    y_cal_bin = (Triana_data_Cal_lbl["Celltype"].values == celltype).astype(int)

    # ---- Info
    print(f"Training - Found {X_tr_df.shape[0]} / {len(all_train_barcodes)} barcodes")
    print(f"Training - Pos: {len(train_positive_barcodes)}, Neg: {len(train_negative_barcodes)}")
    print(f"Training labels: {y_tr.sum()} pos, {len(y_tr)-y_tr.sum()} neg")
    print(f"Testing  - Found {X_te_df.shape[0]} / {len(all_test_barcodes)} barcodes")
    print(f"Testing  - Pos: {len(test_positive_barcodes)}, Neg: {len(test_negative_barcodes)}")
    print(f"Testing labels: {y_te.sum()} pos, {len(y_te)-y_te.sum()} neg")
    print(f"Calibrating - Found {X_cal_df.shape[0]} rows | Pos: {y_cal_bin.sum()}, Neg: {len(y_cal_bin)-y_cal_bin.sum()}")
    print(f"All test data: {X_te_all_local.shape[0]} rows, positives for {celltype}: {y_te_all.sum()}")

    # ---- Scaling (fit on per-head TRAIN slice; transform others)
    scaler = StandardScaler(with_mean=True, with_std=True).fit(X_tr_df.values)

    def _sc(df):
        # Apply this head's scaler, keeping barcodes as index and panel names as columns.
        return pd.DataFrame(
            scaler.transform(df.values),
            index=df.index,
            columns=cols_train,
        )

    X_tr_sc_df     = _sc(X_tr_df)
    X_te_sc_df     = _sc(X_te_df)
    X_te_all_sc_df = _sc(X_te_all_local)
    X_cal_sc_df    = _sc(X_cal_df)

    print(f"[scale] {name}: train mean ~ {X_tr_sc_df.values.mean():.3f}, std ~ {X_tr_sc_df.values.std():.3f}")

    # ---- Base learners
    NB_model  = MLTraining.train_NB (X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    XGB_model = MLTraining.train_XGB(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    KNN_model = MLTraining.train_KNN(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    MLP_model = MLTraining.train_MLP(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)

    # ---- Stacker (raw)
    stacker_raw = StackingClassifier(
        estimators=[("NB", NB_model), ("XGB", XGB_model), ("KNN", KNN_model), ("MLP", MLP_model)],
        final_estimator=LogisticRegression(max_iter=2000, class_weight="balanced", random_state=42),
        stack_method="predict_proba",
        cv=kf,
        n_jobs=-1,
    ).fit(X_tr_sc_df, y_tr)

    # ---- Feature count asserts (debug safety)
    expected_feats = len(cols_train)
    _assert_feature_counts(name, {
        "NB": NB_model, "XGB": XGB_model, "KNN": KNN_model, "MLP": MLP_model, "Stacker": stacker_raw
    }, expected_feats)

    # ---- Binary calibration (Platt, guarded)
    # Sigmoid calibration needs both classes in CAL; otherwise the raw stacker is used.
    pos_cal    = int(y_cal_bin.sum())
    n_cal_bin  = int(len(y_cal_bin))
    has_both   = (0 < pos_cal < n_cal_bin)
    print(f"[CAL] {name}: cal positives={pos_cal}/{n_cal_bin}")

    if has_both:
        try:
            calibrator = CalibratedClassifierCV(estimator=stacker_raw, method="sigmoid", cv="prefit")
        except TypeError:  # older sklearn
            calibrator = CalibratedClassifierCV(base_estimator=stacker_raw, method="sigmoid", cv="prefit")
        stacker = calibrator.fit(X_cal_sc_df, y_cal_bin)
    else:
        print(f"[WARN] Skipping calibration for {name}: single-class cal set.")
        stacker = stacker_raw

    # ---- Calibration plot on all-test (optional)
    try:
        y_proba_uncal = stacker_raw.predict_proba(X_te_all_sc_df)[:, 1]
        y_proba_cal   = stacker.predict_proba(X_te_all_sc_df)[:, 1]
        if has_both:
            _ = MLTraining.plot_calibration_curve(
                y_te_all, [y_proba_uncal, y_proba_cal],
                clf_names=["Uncalibrated", "Calibrated"],
                n_bins=15, strategy="quantile",
                title=f"Calibration – {name_target_class}:{name}"
            )
        else:
            _ = MLTraining.plot_calibration_curve(
                y_te_all, [y_proba_uncal],
                clf_names=["Uncalibrated"],
                n_bins=15, strategy="quantile",
                title=f"Calibration (uncal only) – {name_target_class}:{name}"
            )
        plt.tight_layout()
        plt.show()
    except Exception as e:
        print(f"[WARN] Skipped calibration plot for {name}: {e}")

    # ---- Save per-class bundle (model + scaler + columns)
    save_subdir = models_dir / f"{name_target_class}_{name}"
    save_subdir.mkdir(parents=True, exist_ok=True)

    MLTraining.save_models({"Stacked": stacker}, out_dir=save_subdir, tag=f"{name_target_class}_{name}")
    joblib.dump(cols_train, save_subdir / "feature_names.joblib")

    bundle = {
        "atlas": "Triana",
        "depth": name_target_class,
        "label": celltype,
        "model": stacker,          # CalibratedClassifierCV(StackingClassifier) or StackingClassifier
        "columns": cols_train,     # exact panel names, panel order
        "scaler": scaler,          # per-head scaler
        "panel_name": "TotalSeqD_Heme_Oncology_CAT399906",
    }
    bundle_path = save_subdir / f"{name_target_class}_{name}_bundle.joblib"
    joblib.dump(bundle, bundle_path)
    print(f"[SAVE] Wrote bundle with columns+scaler to {bundle_path}")

    # ---- Binary metrics on the class-specific test slice
    try:
        m = MLTraining.evaluate_classifier(stacker, X_te_sc_df, y_te, plot_cm=False)
        m.update(celltype=celltype)
        metrics_log.append(m)
        print(f"\n{celltype}\n", m.get("report", ""))
    except Exception as e:
        print(f"[WARN] Binary metrics for {name} skipped: {e}")

    # ---- Store OvR probs for multiclass calibration
    P_cal[:, k] = stacker.predict_proba(X_cal_sc_df)[:, 1]
    P_te[:,  k] = stacker.predict_proba(X_te_all_sc_df)[:, 1]

# ============================= MULTICLASS CALIBRATION =============================

print("\nFitting multiclass TemperatureScaling on CAL split...")

# Guards: probs must be finite and in [0,1]. NaN compares False against both
# bounds and would slip through the range test, so reject NaN/Inf first
# (matching the _check_probs guard used for the Platt pipeline).
if not np.isfinite(P_cal).all():
    raise ValueError("P_cal contains NaN/Inf.")
if not np.isfinite(P_te).all():
    raise ValueError("P_te contains NaN/Inf.")
if (P_cal < 0).any() or (P_cal > 1).any():
    raise ValueError("P_cal must be probabilities in [0,1].")
if (P_te < 0).any() or (P_te > 1).any():
    raise ValueError("P_te must be probabilities in [0,1].")

ts_cal = TemperatureScaling()
ts_cal.fit(P_cal, y_cal_multiclass)
P_te_mc = ts_cal.transform(P_te)

# Ensure calibrated probs shape (K)
P_te_mc = np.asarray(P_te_mc)
if P_te_mc.ndim == 1:
    P_te_mc = P_te_mc.reshape(-1, 1)
# Record what TemperatureScaling produced before the fallback overwrites
# P_te_mc, so the warning reports the real (unexpected) shape.
ts_shape = P_te_mc.shape
if P_te_mc.shape[1] == 1 and K == 2:
    # Single column in the binary case (assumed P(class 1)); expand to two columns.
    P_te_mc = np.hstack([1.0 - P_te_mc, P_te_mc])
elif P_te_mc.shape[1] != K:
    # Fallback: normalize OvR sums (all-zero rows stay zero instead of NaN)
    row_sums = P_te.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0.0] = 1.0
    P_te_mc = P_te / row_sums
    print(f"[WARN] TemperatureScaling returned shape {ts_shape}; fell back to sum-normalized OvR probs.")

# Persist multiclass temp scaler + class names
joblib.dump(ts_cal, models_dir / f"{name_target_class}_multiclass_temp_scaler.joblib")
(pd.Series(class_names, name="class_name")
   .to_csv(models_dir / f"{name_target_class}_class_names.csv", index=False))

# ============================= PROBS COMPARISON & METRICS =============================

probs_raw_df = pd.DataFrame(P_te,    index=test_index, columns=[f"raw_{c}" for c in class_names])
probs_mc_df  = pd.DataFrame(P_te_mc, index=test_index, columns=[f"mc_{c}"  for c in class_names])

probs_compare = pd.concat([probs_raw_df, probs_mc_df], axis=1)
probs_compare["true_label"]    = Triana_data_Test["Celltype"].values
probs_compare["pred_raw"]      = P_te.argmax(axis=1)
probs_compare["pred_mc"]       = P_te_mc.argmax(axis=1)
probs_compare["pred_raw_name"] = [class_names[i] for i in probs_compare["pred_raw"].values]
probs_compare["pred_mc_name"]  = [class_names[i] for i in probs_compare["pred_mc"].values]

print("\nPreview of probabilities BEFORE (raw OvR) vs AFTER (multiclass TS):")
print(probs_compare.head(10).to_string())

probs_compare_path = models_dir / f"{name_target_class}_probabilities_before_after_TEST.csv"
probs_compare.to_csv(probs_compare_path, index=True)
print(f"\nSaved probabilities comparison to: {probs_compare_path}")

# Multiclass evaluation
y_pred_mc = P_te_mc.argmax(axis=1)
print("\nMulticlass classification report (TEST):")
print(classification_report(y_test_multiclass, y_pred_mc, target_names=class_names, digits=3))

cm = confusion_matrix(y_test_multiclass, y_pred_mc, labels=range(K))
print("Confusion matrix (rows=true, cols=pred):\n", cm)

# Per-class binary head metrics CSV
metrics_df = pd.DataFrame.from_records(metrics_log)
MLTraining.append_metrics_csv(metrics_df, csv_path=Path(models_output) / "stacker_metrics.csv")

print("\nDone.")


#### Detailed annotation

In [None]:
# -*- coding: utf-8 -*-
# =============================================================================
# MODEL TRAINING PIPELINE (LEAN MAIN SCRIPT)
#   - RAW vs PLATT vs TEMP-SCALED
#   - DEV/RELEASE exports
#   - Importances: XGB SHAP mean_abs + corr (Top10) + LR meta-learner contributions
#   - Platt calibration plots (Ideal -> RAW -> Platt on top) with TEST LogLoss/Brier in legend
#   - Per-class pre/post Platt metrics exported to CSV
#   - Per-class TRAIN UMAP (pos vs rest) + legend PNG
#
# PATCHES ADDED (to address “plots missing / skipped” symptoms):
#   (A) Optional DEBUG_DIAGNOSTICS: prints output paths + CAL class balance + confirms file writes.
#   (B) Hard traceback on failures (instead of silent warnings) to surface root cause.
#   (C) SHAP beeswarm robustification: optional subsample of TRAIN to avoid memory/time failures.
#   (D) Optional SAFE_SINGLE_THREAD: mitigates fork/thread/numba/TBB instability during SHAP/plotting.
#   (E) Explicit existence checks after savefig (so “saved but not where expected” is obvious).
# =============================================================================

# =============================================================================
# SECTION 0: IMPORTS + CONFIG
# =============================================================================

from pathlib import Path
import joblib
import warnings
import traceback
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import StackingClassifier
from sklearn.metrics import (
    classification_report, confusion_matrix, ConfusionMatrixDisplay,
    roc_curve, auc,
)

import MLTraining  # uses MLTraining.py helpers

# -----------------------------------------------------------------------------
# Palettes
# -----------------------------------------------------------------------------
# Hex colour per cell-type label, one mapping per annotation depth.
# Insertion order is kept stable so legends render in a predictable order.

PALETTE_BROAD = dict(Immature="#0079ea", Mature="#AF3434")

PALETTE_SIMPLIFIED = dict(
    HSPC="#0079ea",
    Erythroid="#c11212",
    pDC="#62E6B8",
    Monocyte="#D27CE3",
    Myeloid="#8D43CD",
    CD4_T="#C99546",
    CD8_T="#6B3317",
    B="#68D827",
    cDC="#16D2E3",
    Other_T="#EDB416",
    NK="#FBEF0D",
)

# Detailed labels contain spaces and hyphens, so they are listed as (label, colour) pairs.
PALETTE_DETAILED = dict([
    ("HSC_MPP",            "#0079ea"),
    ("LMPP",               "#17BECF"),
    ("GMP",                "#C5E4FF"),
    ("Myeloid progenitor", "#AEC7E8"),
    ("Monocyte",           "#D27CE3"),
    ("CD14 Mono",          "#D27CE3"),
    ("CD16 Mono",          "#8D43CD"),
    ("Erythroblast",       "#F30A1A"),
    ("ErP",                "#D1235A"),
    ("MEP",                "#E364B0"),
    ("CD4 T Naive",        "#C99546"),
    ("CD4 T Memory",       "#C1AF93"),
    ("CD8 T Naive",        "#4D382E"),
    ("CD8 T Memory",       "#6B3317"),
    ("Other_T",            "#EDB416"),
    ("Treg",               "#6E6C37"),
    ("B Naive",            "#1C511D"),
    ("B Memory",           "#68D827"),
    ("Pro-B",              "#66BB6A"),
    ("Pre-B",              "#2DBD67"),
    ("Immature B",         "#91FF7B"),
    ("Plasma",             "#9DC012"),
    ("cDC1",               "#76A7CB"),
    ("cDC2",               "#16D2E3"),
    ("pDC",                "#69FFCB"),
    ("NK CD56 bright",     "#F3AC1F"),
    ("NK CD56 dim",        "#FBEF0D"),
])

# Lookup used with `name_target_class` to pick the palette for the current depth.
PALETTE_BY_DEPTH = dict(
    Broad=PALETTE_BROAD,
    Simplified=PALETTE_SIMPLIFIED,
    Detailed=PALETTE_DETAILED,
)

# -----------------------------------------------------------------------------
# OPTIONAL: SHAP dependency
# -----------------------------------------------------------------------------
# SHAP is optional: any failure while importing it (missing package or a broken
# native dependency) simply disables the SHAP-based steps.
try:
    import shap  # noqa: F401
except Exception:
    HAS_SHAP = False
else:
    HAS_SHAP = True

# -----------------------------------------------------------------------------
# EXPORT SWITCHES
# -----------------------------------------------------------------------------
EXPORT_RELEASE = True
EXPORT_DEV     = False

# -----------------------------------------------------------------------------
# CONFIG
# -----------------------------------------------------------------------------
name_target_class = "Detailed"  # "Broad" | "Simplified" | "Detailed"
# Class labels to leave out of training/evaluation (compared as strings).
# Must be a set: `{}` would be an empty dict, not an empty set.
EXCLUDE_CLASSES = set()

custom_palette = PALETTE_BY_DEPTH.get(name_target_class, {})
kf          = MLTraining.CV
num_cores   = -1
metrics_log = []

# -----------------------------------------------------------------------------
# DIAGNOSTICS / ROBUSTIFICATION SWITCHES (PATCH)
# -----------------------------------------------------------------------------
DEBUG_DIAGNOSTICS = True
HARD_TRACEBACKS   = True   # if True: prints stack traces when plot/SHAP fails
SHAP_TRAIN_SUBSAMPLE_MAX_N = 5000  # set None to disable subsampling
SAFE_SINGLE_THREAD = False  # set True if you see Numba/TBB fork/thread warnings

if SAFE_SINGLE_THREAD:
    # Run the parallelised training steps in-process instead of spawning workers.
    # NOTE(review): assumes MLTraining trainers forward `num_cores` as n_jobs — confirm.
    num_cores = 1

# -----------------------------------------------------------------------------
# EMBEDDING CONFIG (for Class_Train_data.png)
# -----------------------------------------------------------------------------
EMBEDDING_SOURCE = "adata_obsm"   # "adata_obsm" | "adata_obs" | "train_df"
EMBEDDING_OBSM_KEY = "X_mofaumap"
EMBEDDING_OBS_X = "UMAP_1"
EMBEDDING_OBS_Y = "UMAP_2"
EMBEDDING_DF_X = "UMAP_1"
EMBEDDING_DF_Y = "UMAP_2"

# -----------------------------------------------------------------------------
# ROOTS
# -----------------------------------------------------------------------------
# Everything is written under `models_output`. There are two trees: a DEV tree
# (per-depth Models/Reports/Figures, each nested once more by depth) and a
# flatter RELEASE tree.
Triana_root = Path(models_output)

# ---- DEV layout
dev_root     = Triana_root / "Dev"
models_root  = dev_root.joinpath(name_target_class, "Models", name_target_class)
reports_root = dev_root.joinpath(name_target_class, "Reports", name_target_class)
fig_root     = dev_root.joinpath(name_target_class, "Figures", name_target_class)

heads_dir       = models_root / "heads"
metrics_dir     = reports_root / "metrics"
probs_dir       = reports_root / "probabilities"
fig_percls      = fig_root / "per_class"
dev_importances = reports_root / "Importances"

# ---- RELEASE layout
release_root    = Triana_root / "Release"
release_models  = release_root.joinpath(name_target_class, "Models")
release_reports = release_root.joinpath(name_target_class, "Reports")
release_metrics = release_reports / "Metrics"
release_probs   = release_reports / "Probabilities"
release_imps    = release_reports / "Importances"
release_figs    = release_root.joinpath(name_target_class, "Figures")
release_single  = release_figs / "Single_classes"

# Only create the trees that are actually going to be written to.
if EXPORT_DEV:
    for p in [models_root, heads_dir, reports_root, metrics_dir,
              probs_dir, fig_root, fig_percls, dev_importances]:
        p.mkdir(parents=True, exist_ok=True)

if EXPORT_RELEASE:
    for p in [release_models, release_reports, release_metrics, release_probs,
              release_imps, release_figs, release_single]:
        p.mkdir(parents=True, exist_ok=True)
    print(
        f"[INFO] RELEASE Root:    {release_root}\n"
        f"[INFO] RELEASE Models:  {release_models}\n"
        f"[INFO] RELEASE Reports: {release_reports}\n"
        f"[INFO] RELEASE Figures: {release_figs}"
    )

if DEBUG_DIAGNOSTICS:
    print(
        f"[DEBUG] HAS_SHAP={HAS_SHAP} EXPORT_RELEASE={EXPORT_RELEASE} EXPORT_DEV={EXPORT_DEV}\n"
        f"[DEBUG] release_single={release_single}\n"
        f"[DEBUG] release_imps={release_imps}\n"
        f"[DEBUG] SAFE_SINGLE_THREAD={SAFE_SINGLE_THREAD} SHAP_SUBSAMPLE_MAX_N={SHAP_TRAIN_SUBSAMPLE_MAX_N}"
    )

# =============================================================================
# SECTION 1: ATTACH CELL-TYPE LABELS
# =============================================================================
print("\n[STEP 1] Attaching cell-type labels from AnnData.obs...")

# obs column holding the consensus annotation for the selected depth.
consensus_field = "Consensus_annotation_" + name_target_class.lower() + "_final"

Triana_data_Train = MLTraining.attach_celltype(Triana_data_Train, Triana_dataset_Train, consensus_field)
Triana_data_Test  = MLTraining.attach_celltype(Triana_data_Test,  Triana_dataset_Test,  consensus_field)
Triana_data_Cal   = MLTraining.attach_celltype(Triana_data_Cal,   Triana_dataset_Cal,   consensus_field)

print(f"  ✓ Attached '{consensus_field}' to Train/Test/Cal splits")

# =============================================================================
# SECTION 2: ALIGN DATA COLUMNS TO REFERENCE PANEL
# =============================================================================
print("\n[STEP 2] Aligning data columns to reference panel (exact names preserved)...")

# Panel names as strings, plus a normalized-key -> exact-panel-name lookup.
panel = pd.Index([str(marker) for marker in TotalSeqD_Heme_Oncology_CAT399906])
panel_keys = MLTraining.norm_feats(panel)
norm_to_panel = dict(zip(panel_keys, panel))

# Two panel names collapsing onto one key would make the lookup ambiguous.
if len(norm_to_panel) < len(panel):
    raise ValueError("Panel contains names that collide after normalization. Adjust MLTraining.norm_feats rules.")

def rename_data_to_panel(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with feature columns renamed to their exact panel names.

    Feature columns are every column except ``cell_barcode`` / ``Celltype``.
    Each one is normalised with ``MLTraining.norm_feats`` and looked up in the
    module-level ``norm_to_panel`` mapping. Columns without a panel match keep
    their name. When several columns map to the same panel name, the first
    (in column order) is renamed and the rest are dropped.
    """
    out = df.copy()
    meta_cols = [name for name in ("cell_barcode", "Celltype") if name in out.columns]
    features = pd.Index([name for name in out.columns if name not in meta_cols])

    # Candidate renames: original column -> panel name (unmatched columns omitted).
    candidates = {}
    for original, key in zip(features, MLTraining.norm_feats(features)):
        target = norm_to_panel.get(key)
        if target is not None:
            candidates[original] = target

    # The first column claiming a panel name wins; later claimants are dropped.
    claimed = set()
    rename_map, duplicates = {}, []
    for original, target in candidates.items():
        if target in claimed:
            duplicates.append(original)
            continue
        claimed.add(target)
        rename_map[original] = target

    if duplicates:
        print(f"  [WARN] Dropping {len(duplicates)} duplicated-mapped columns (sample: {duplicates[:5]})")
        out.drop(columns=duplicates, inplace=True, errors="ignore")

    out.rename(columns=rename_map, inplace=True)
    print(f"  ✓ Matched {len(rename_map)}/{len(features)} data columns to panel")
    return out

def panel_intersection(df: pd.DataFrame) -> pd.DataFrame:
    """Restrict ``df`` to the panel features it contains, followed by its metadata columns.

    The features come first, in the order given by the module-level ``panel``
    index, then whichever of ``cell_barcode`` / ``Celltype`` are present.

    Raises:
        ValueError: if no feature column of ``df`` is in the panel.
    """
    meta = [name for name in ("cell_barcode", "Celltype") if name in df.columns]
    present = pd.Index([name for name in df.columns if name not in meta])
    shared = panel.intersection(present, sort=False)
    if shared.empty:
        raise ValueError("Panel/Data intersection is empty after renaming. Check mapping rules.")
    return df.reindex(columns=[*shared, *meta])

# Rename each split's columns to panel names, then keep only shared panel features.
Triana_data_Train, Triana_data_Test, Triana_data_Cal = (
    panel_intersection(rename_data_to_panel(split))
    for split in (Triana_data_Train, Triana_data_Test, Triana_data_Cal)
)
print("  ✓ Data columns now aligned to panel (panel order preserved)")

# =============================================================================
# SECTION 3: PREPARE FEATURES & LABELS (WITH CAL/TEST ROW FILTERING)
# =============================================================================
print("\n[STEP 3] Extracting features and labels...")

Triana_data_Cal_lbl = Triana_data_Cal[["Celltype"]].copy()

# Feature matrices = every column except the barcode / label metadata.
drop_cols_train = [c for c in ["cell_barcode", "Celltype"] if c in Triana_data_Train.columns]
drop_cols_test  = [c for c in ["cell_barcode", "Celltype"] if c in Triana_data_Test.columns]
drop_cols_cal   = [c for c in ["cell_barcode", "Celltype"] if c in Triana_data_Cal.columns]

Triana_data_Train_Sub = Triana_data_Train.drop(columns=drop_cols_train, errors="ignore")
Triana_data_Test_Sub  = Triana_data_Test.drop(columns=drop_cols_test,  errors="ignore")
Triana_data_Cal_Sub   = Triana_data_Cal.drop(columns=drop_cols_cal,    errors="ignore")

# Every split must expose the exact same feature columns, in the same order.
cols_train = list(Triana_data_Train_Sub.columns)
if list(Triana_data_Test_Sub.columns) != cols_train or list(Triana_data_Cal_Sub.columns) != cols_train:
    raise ValueError("Train/Cal/Test feature columns differ after panel intersection!")

MLTraining.check_finite(Triana_data_Train_Sub, "TRAIN")
MLTraining.check_finite(Triana_data_Test_Sub,  "TEST")
MLTraining.check_finite(Triana_data_Cal_Sub,   "CAL")

print(f"  ✓ Using {len(cols_train)} panel-intersected features (exact panel names)")
print(f"    Sample: {cols_train[:5]}...")

# classes learned from TRAIN, excluding user-specified (matched as strings)
all_classes = sorted(pd.Series(Triana_data_Train["Celltype"]).dropna().unique())
class_names = [c for c in all_classes if str(c) not in EXCLUDE_CLASSES]

# Use the same str-based test as the filter above, so the log lists exactly the
# classes that were removed (also when labels are not plain str).
excluded_present = [c for c in all_classes if str(c) in EXCLUDE_CLASSES]
if excluded_present:
    print(f"  [INFO] Excluding {len(excluded_present)} classes: {excluded_present}")

K            = len(class_names)
class_to_idx = {c: i for i, c in enumerate(class_names)}
print(f"  ✓ Found {K} classes after exclusions")

# ---- critical: filter CAL/TEST rows to those classes ----
keep_set = set(map(str, class_names))

cal_keep_mask  = Triana_data_Cal_lbl["Celltype"].astype(str).isin(keep_set)
test_keep_mask = Triana_data_Test["Celltype"].astype(str).isin(keep_set)

n_cal_drop  = int((~cal_keep_mask).sum())
n_test_drop = int((~test_keep_mask).sum())

if n_cal_drop > 0:
    dropped = sorted(Triana_data_Cal_lbl.loc[~cal_keep_mask, "Celltype"].astype(str).unique().tolist())
    print(f"  [INFO] Dropping {n_cal_drop} CAL rows with excluded/unknown labels: {dropped}")

if n_test_drop > 0:
    dropped = sorted(Triana_data_Test.loc[~test_keep_mask, "Celltype"].astype(str).unique().tolist())
    print(f"  [INFO] Dropping {n_test_drop} TEST rows with excluded/unknown labels: {dropped}")

# filtered label frames
Triana_data_Cal_lbl_f  = Triana_data_Cal_lbl.loc[cal_keep_mask].copy()
Triana_data_Test_lbl_f = Triana_data_Test.loc[test_keep_mask, ["Celltype"]].copy()

# filtered feature frames (must align by index)
X_cal_all_df = Triana_data_Cal_Sub.loc[Triana_data_Cal_lbl_f.index].copy()
X_te_all_df  = Triana_data_Test_Sub.loc[Triana_data_Test_lbl_f.index].copy()
test_index   = X_te_all_df.index

# map filtered labels to integer class indices
s_cal = Triana_data_Cal_lbl_f["Celltype"].map(class_to_idx)
if s_cal.isna().any():
    missing = Triana_data_Cal_lbl_f.loc[s_cal.isna(), "Celltype"].unique()
    raise ValueError(f"Still-unmapped labels in CAL after filtering: {missing}")
y_cal_multiclass = s_cal.to_numpy(dtype=np.int64)

s_te = Triana_data_Test_lbl_f["Celltype"].map(class_to_idx)
if s_te.isna().any():
    missing = Triana_data_Test_lbl_f.loc[s_te.isna(), "Celltype"].unique()
    raise ValueError(f"Still-unmapped labels in TEST after filtering: {missing}")
y_test_multiclass = s_te.to_numpy(dtype=np.int64)

if DEBUG_DIAGNOSTICS:
    # Per-class Platt calibration needs both positives and negatives in CAL.
    # A class with 0 CAL rows here keeps its RAW head probabilities.
    print("  [DEBUG] CAL class balance: " + ", ".join(
        f"{c}={n}" for c, n in zip(class_names, np.bincount(y_cal_multiclass, minlength=K))
    ))

# probability matrices sized to filtered CAL/TEST (one OvR column per class)
P_cal_raw   = np.zeros((X_cal_all_df.shape[0], K), dtype=float)
P_cal_platt = np.zeros((X_cal_all_df.shape[0], K), dtype=float)

P_te_raw    = np.zeros((X_te_all_df.shape[0],  K), dtype=float)
P_te_platt  = np.zeros((X_te_all_df.shape[0],  K), dtype=float)

heads_mem = {}

xgb_shap_rows      = []
lr_contrib_rows    = []
platt_metrics_rows = []

# =============================================================================
# SECTION 4: TRAIN OvR BINARY HEADS (+ Platt on CAL)
# =============================================================================
print(f"\n[STEP 4] Training {K} binary OvR classifiers...\n")

TOP_N = 10                                 # features shown/exported per class
base_order = ["NB", "XGB", "KNN", "MLP"]   # base-learner names inside each StackingClassifier


def _report_failure(context: str, exc: BaseException) -> None:
    """Report a non-fatal per-class failure without aborting the training loop.

    The message always goes through ``warnings.warn``. This notebook silences
    warnings globally, so when ``HARD_TRACEBACKS`` is enabled the message and
    the full stack trace are also printed. That way skipped outputs are visible.
    """
    message = f"{context}: {exc}"
    warnings.warn(message)
    if HARD_TRACEBACKS:
        print(f"    [ERROR] {message}")
        traceback.print_exception(type(exc), exc, exc.__traceback__)


def _confirm_written(path) -> None:
    """If ``DEBUG_DIAGNOSTICS`` is on, print whether an expected output file exists."""
    if DEBUG_DIAGNOSTICS and path is not None:
        status = "written" if Path(path).exists() else "MISSING"
        print(f"    [DEBUG] {status}: {path}")


# Loop-invariant string labels of the filtered TEST / CAL rows; compared per class below.
_test_labels_str = Triana_data_Test.loc[X_te_all_df.index, "Celltype"].astype(str).values
_cal_labels_str  = Triana_data_Cal.loc[X_cal_all_df.index, "Celltype"].astype(str).values

for celltype in class_names:
    k = class_to_idx[celltype]
    cls_safe = MLTraining.safe_name(celltype)
    print(f"▸ Processing {cls_safe} (class {k+1}/{K})")

    # 4.1 Load TRAIN barcodes for this class
    train_barcodes_df = pd.read_csv(
        f"{train_barcodes_path}/Triana/Consensus_annotation_{name_target_class.lower()}_final/Barcodes_training_class_{cls_safe}.csv",
        index_col=0
    )
    train_positive_barcodes = train_barcodes_df["Positive"].dropna().values
    train_negative_barcodes = train_barcodes_df["Negative"].dropna().values
    all_train_barcodes = np.concatenate([train_positive_barcodes, train_negative_barcodes])

    train_mask = Triana_data_Train_Sub.index.isin(all_train_barcodes)
    X_tr_df = Triana_data_Train_Sub.loc[train_mask]
    found_train_barcodes = X_tr_df.index.values
    y_tr = np.isin(found_train_barcodes, train_positive_barcodes).astype(int)

    if X_tr_df.empty or np.unique(y_tr).size < 2:
        # No head is trained: column k of the P_* matrices stays all-zero.
        print(f"  [SKIP] Empty or single-class train (pos={y_tr.sum()}, neg={len(y_tr)-y_tr.sum()})\n")
        continue

    # 4.1b TRAIN embedding (pos vs rest) + legend
    try:
        MLTraining.save_class_train_umap_pngs(
            celltype=str(celltype),
            cls_safe=cls_safe,
            barcodes=found_train_barcodes,
            y_bin=y_tr,
            custom_palette=custom_palette,
            out_dir_dev=fig_percls if EXPORT_DEV else None,
            out_dir_rel=release_single if EXPORT_RELEASE else None,
            adata_train=Triana_dataset_Train,
            train_df=Triana_data_Train,
            embedding_source=EMBEDDING_SOURCE,
            obsm_key=EMBEDDING_OBSM_KEY,
            obs_x=EMBEDDING_OBS_X,
            obs_y=EMBEDDING_OBS_Y,
            df_x=EMBEDDING_DF_X,
            df_y=EMBEDDING_DF_Y,
            neg_color="#A3A3A3",
            outline=(5, 0.05),
            debug=False,
        )
    except Exception as e:
        _report_failure(f"UMAP train plot failed for '{celltype}'", e)

    # 4.2 Load TEST barcodes for class-specific metrics (optional)
    test_barcodes_df = pd.read_csv(
        f"{test_barcodes_path}/Triana/Consensus_annotation_{name_target_class.lower()}_final/Barcodes_testing_class_{cls_safe}.csv",
        index_col=0
    )
    test_positive_barcodes = test_barcodes_df["Positive"].dropna().values
    test_negative_barcodes = test_barcodes_df["Negative"].dropna().values
    all_test_barcodes = np.concatenate([test_positive_barcodes, test_negative_barcodes])

    test_mask = Triana_data_Test_Sub.index.isin(all_test_barcodes)
    X_te_df = Triana_data_Test_Sub.loc[test_mask]
    found_test_barcodes = X_te_df.index.values
    y_te = np.isin(found_test_barcodes, test_positive_barcodes).astype(int)

    # Full TEST (filtered) for head probabilities / calibration plot eval
    X_te_all_local = X_te_all_df
    y_te_all = (_test_labels_str == str(celltype)).astype(int)

    # CAL split (filtered) for Platt fitting
    X_cal_df  = X_cal_all_df
    y_cal_bin = (_cal_labels_str == str(celltype)).astype(int)

    # 4.3 Fit scaler on TRAIN; transform all splits
    scaler = StandardScaler(with_mean=True, with_std=True).fit(X_tr_df.values)

    def _sc(df: pd.DataFrame) -> pd.DataFrame:
        """Scale ``df`` with this head's TRAIN scaler, keeping index and panel column names."""
        return pd.DataFrame(scaler.transform(df.values), index=df.index, columns=cols_train)

    X_tr_sc_df      = _sc(X_tr_df)
    X_te_sc_df      = _sc(X_te_df)
    X_te_all_sc_df  = _sc(X_te_all_local)
    X_cal_sc_df     = _sc(X_cal_df)

    # 4.4 Train base learners
    NB_model  = MLTraining.train_NB (X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)
    XGB_model = MLTraining.train_XGB(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)
    KNN_model = MLTraining.train_KNN(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)
    MLP_model = MLTraining.train_MLP(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)

    # 4.5 Stacking RAW head (parallelism follows the same `num_cores` setting as the base learners)
    stacker_raw = StackingClassifier(
        estimators=[("NB", NB_model), ("XGB", XGB_model), ("KNN", KNN_model), ("MLP", MLP_model)],
        final_estimator=LogisticRegression(max_iter=2000, class_weight="balanced", random_state=42),
        stack_method="predict_proba",
        cv=kf,
        n_jobs=num_cores,
    ).fit(X_tr_sc_df, y_tr)

    # 4.6 Platt calibration (fit on CAL only); needs both classes present in CAL
    pos_cal   = int(y_cal_bin.sum())
    n_cal_bin = int(len(y_cal_bin))
    has_both  = (0 < pos_cal < n_cal_bin)

    stacker_platt = None
    if has_both:
        stacker_platt = MLTraining.calibrate_prefit(stacker_raw, X_cal_sc_df, y_cal_bin, method="sigmoid")
    else:
        print("    [WARN] Skipped Platt calibration (single-class CAL)")

    # 4.6b Positive-class probabilities on CAL / full TEST, computed once and reused below
    p_cal_raw  = stacker_raw.predict_proba(X_cal_sc_df)[:, 1]
    p_test_raw = stacker_raw.predict_proba(X_te_all_sc_df)[:, 1]
    if stacker_platt is not None:
        p_cal_platt  = stacker_platt.predict_proba(X_cal_sc_df)[:, 1]
        p_test_platt = stacker_platt.predict_proba(X_te_all_sc_df)[:, 1]
    else:
        p_cal_platt, p_test_platt = None, None

    # 4.7 Platt evaluation curve on TEST (Ideal -> RAW -> Platt) + metrics row
    dev_platt = (fig_percls / f"{cls_safe}_Platt_calibration_evaluation_on_test.png") if EXPORT_DEV else None
    rel_platt = (release_single / f"{cls_safe}_Platt_calibration_evaluation_on_test.png") if EXPORT_RELEASE else None
    try:
        ll_raw, br_raw, ll_pl, br_pl, pl_avail = MLTraining.plot_platt_calibration_on_test(
            y_true_bin=y_te_all.astype(int),
            p_raw=p_test_raw,
            p_platt=p_test_platt,
            title=f"{name_target_class} – {celltype}: Platt calibration evaluation on TEST",
            out_png_dev=dev_platt,
            out_png_rel=rel_platt,
            n_bins=15,
        )

        platt_metrics_rows.append({
            "depth": name_target_class,
            "class_name": str(celltype),
            "n_test_samples": int(len(y_te_all)),
            "n_test_positive": int(y_te_all.sum()),
            "logloss_raw": ll_raw,
            "brier_raw": br_raw,
            "logloss_platt": ll_pl,
            "brier_platt": br_pl,
            "platt_available": bool(pl_avail),
        })
    except Exception as e:
        _report_failure(f"Platt calibration plot failed for class '{celltype}'", e)
    _confirm_written(dev_platt)
    _confirm_written(rel_platt)

    # 4.8 Save per-class head bundle + keep in-memory for package
    head_bundle = {
        "atlas": "Triana",
        "depth": name_target_class,
        "label": str(celltype),
        "panel_name": "TotalSeqD_Heme_Oncology_CAT399906",
        "columns": cols_train,
        "scaler": scaler,
        "model_raw": stacker_raw,
        "model_platt": stacker_platt,
    }
    heads_mem[str(celltype)] = head_bundle

    if EXPORT_DEV:
        head_path = heads_dir / f"{cls_safe}.joblib"
        joblib.dump(head_bundle, head_path)
        _confirm_written(head_path)

    # 4.9 Optional per-head metrics logging (class-specific TEST subset)
    model_for_eval = stacker_platt if stacker_platt is not None else stacker_raw
    try:
        m = MLTraining.evaluate_classifier(model_for_eval, X_te_sc_df, y_te, plot_cm=False)
        m.update(celltype=str(celltype), used_platt=bool(stacker_platt is not None))
        metrics_log.append(m)
    except Exception as e:
        # Best-effort, but no longer silent: a missing metrics row is now reported.
        _report_failure(f"Per-head metrics failed for class '{celltype}'", e)

    # 4.10 OvR probability matrices (RAW + PLATT) for multiclass downstream.
    # Without a Platt head (single-class CAL), the RAW probabilities stand in.
    P_cal_raw[:, k] = p_cal_raw
    P_te_raw[:,  k] = p_test_raw
    P_cal_platt[:, k] = p_cal_platt if p_cal_platt is not None else p_cal_raw
    P_te_platt[:,  k] = p_test_platt if p_test_platt is not None else p_test_raw

    # 4.11 SHAP: mean_abs + corr on TEST; beeswarm TRAIN only
    if HAS_SHAP:
        try:
            # NOTE(review): figure opened before the SHAP helpers as in the original; unclear
            # whether MLTraining draws into it — if not, it leaks one figure per class.
            plt.figure(figsize=(6, 6))
            shap_sum_test = MLTraining.xgb_shap_mean_abs_and_corr(XGB_model, X_te_all_sc_df, class_index=1)
            shap_sum_test["depth"] = name_target_class
            shap_sum_test["class_name"] = str(celltype)
            shap_sum_test["dataset"] = "TEST"
            xgb_shap_rows.extend(shap_sum_test.to_dict(orient="records"))
        except Exception as e:
            _report_failure(f"SHAP failed for class '{celltype}'", e)

        # Beeswarm on TRAIN only. SHAP values are computed once and reused for every
        # export target. Large TRAIN sets are subsampled to bound time/memory.
        beeswarm_paths = []
        if EXPORT_DEV:
            beeswarm_paths.append(fig_percls / f"{cls_safe}_SHAP_beeswarm_TRAIN.png")
        if EXPORT_RELEASE:
            beeswarm_paths.append(release_single / f"{cls_safe}_SHAP_beeswarm_TRAIN.png")
        if beeswarm_paths:
            try:
                X_bee = X_tr_sc_df
                if SHAP_TRAIN_SUBSAMPLE_MAX_N is not None and len(X_bee) > SHAP_TRAIN_SUBSAMPLE_MAX_N:
                    X_bee = X_bee.sample(n=SHAP_TRAIN_SUBSAMPLE_MAX_N, random_state=42)
                sv, _ = MLTraining.xgb_shap_values(XGB_model, X_bee, class_index=1)
                for outp in beeswarm_paths:
                    MLTraining.plot_xgb_shap_beeswarm(
                        sv, X_bee,
                        title=f"{name_target_class} – {celltype}: SHAP importance",
                        max_display=TOP_N,
                        out_path=outp,
                        figsize=(6, 7),
                    )
                    _confirm_written(outp)
            except Exception as e:
                _report_failure(f"SHAP beeswarm failed for class '{celltype}'", e)

    # 4.12 LR meta-learner contributions (unchanged)
    try:
        contrib = _lr_baselearner_contributions(stacker_raw, X_te_all_sc_df, base_order=base_order)
        row = {
            "depth": name_target_class,
            "class_name": str(celltype),
            "dataset": "TEST",
            "n_meta_features": contrib["n_meta_features"],
            "per_estimator_meta_cols": contrib["per_estimator_meta_cols"],
        }
        for b in base_order:
            row[f"{b}_mean_abs_contribution"] = contrib["per_base"].get(b, {}).get("mean_abs_contribution", 0.0)
            row[f"{b}_coef_l1"]               = contrib["per_base"].get(b, {}).get("coef_l1", 0.0)
            row[f"{b}_n_meta_cols"]           = contrib["per_base"].get(b, {}).get("n_cols", 0)
        lr_contrib_rows.append(row)
    except Exception as e:
        _report_failure(f"LR contribution extraction failed for class '{celltype}'", e)

    print("")

# =============================================================================
# EXPORT: Per-class LogLoss & Brier (pre vs post Platt) on TEST
# =============================================================================
print("\n[EXPORT] Per-class calibration metrics (RAW vs Platt on TEST)...")

# One CSV, written to whichever of the DEV / RELEASE metric folders is enabled.
_ = MLTraining.export_platt_metrics_csv(
    platt_metrics_rows,
    filename="Single_classes_metrics_pre_and_post_platt_calibration.csv",
    out_dev=(metrics_dir if EXPORT_DEV else None),
    out_rel=(release_metrics if EXPORT_RELEASE else None),
)

# =============================================================================
# SECTION 5: MULTICLASS TEMPERATURE SCALING (fit on CAL using PLATT matrix)
# =============================================================================
print("\n[STEP 5] Multiclass Temperature Scaling on CAL (using Platt OvR probabilities)...")

def _check_probs(P: np.ndarray, name: str):
    if np.isnan(P).any() or np.isinf(P).any():
        raise ValueError(f"{name} contains NaN/Inf")
    if (P < 0).any() or (P > 1).any():
        raise ValueError(f"{name} contains values outside [0,1]")

_check_probs(P_cal_platt, "P_cal_platt")
_check_probs(P_te_platt,  "P_te_platt")

# Fit one temperature on CAL (Platt OvR matrix) and apply it to TEST.
ts_cal = TemperatureScaling()
ts_cal.fit(P_cal_platt, y_cal_multiclass)
P_te_cal = np.asarray(ts_cal.transform(P_te_platt))

# Ensure calibrated probs have shape (n_test, K).
if P_te_cal.ndim == 1:
    P_te_cal = P_te_cal.reshape(-1, 1)
# Remember what TemperatureScaling actually returned; P_te_cal may be replaced below.
ts_returned_shape = P_te_cal.shape

if P_te_cal.shape[1] == 1 and K == 2:
    # Single positive-class column for a binary problem -> expand to two columns.
    P_te_cal = np.hstack([1.0 - P_te_cal, P_te_cal])
elif P_te_cal.shape[1] != K:
    # Fallback: normalize OvR sums (all-zero rows are left as zeros).
    row_sums = P_te_platt.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0.0] = 1.0
    P_te_cal = P_te_platt / row_sums
    print(f"  [WARN] TemperatureScaling returned shape {ts_returned_shape}; fell back to sum-normalized OvR probs")

if EXPORT_DEV:
    joblib.dump(ts_cal, models_root / "temp_scaler.joblib")
    pd.Series(class_names, name="class_name").to_csv(models_root / "class_names.csv", index=False)

# =============================================================================
# SECTION 5b: SAVE DEPLOYABLE PACKAGE(S)
# =============================================================================
print("\n[STEP 5b] Saving deployable package(s)...")

# Self-contained inference bundle: per-class heads (scaler + RAW/Platt models)
# plus the multiclass temperature scaler fitted in Section 5.
package = dict(
    atlas="Triana",
    depth=name_target_class,
    panel_name="TotalSeqD_Heme_Oncology_CAT399906",
    class_names=class_names,
    heads=heads_mem,
    temp_scaler=ts_cal,
)

if EXPORT_DEV:
    joblib.dump(package, models_root / "package.joblib")
if EXPORT_RELEASE:
    joblib.dump(package, release_models / "Multiclass_models.joblib")

# =============================================================================
# SECTION 5c: EXPORT IMPORTANCES (Top10 per class)
# =============================================================================
print("\n[STEP 5c] Exporting importances (Top 10 per class; SHAP mean_abs + corr + LR)...")

if xgb_shap_rows:
    # Keep the TOP_N features with the largest mean |SHAP| per (depth, class) and
    # number them 1..TOP_N within each group (method="first": ties by row order).
    _shap_group = ["depth", "class_name"]
    _shap_sorted = pd.DataFrame(xgb_shap_rows).sort_values(
        _shap_group + ["mean_abs_shap"], ascending=[True, True, False]
    )
    shap_df = _shap_sorted.groupby(_shap_group, as_index=False).head(TOP_N)
    _shap_rank = shap_df.groupby(_shap_group)["mean_abs_shap"].rank(ascending=False, method="first")
    shap_df["rank_within_class"] = _shap_rank.astype(int)
    shap_df = shap_df[[
        "depth", "class_name", "dataset", "feature",
        "mean_abs_shap", "corr_feature_value_vs_shap", "rank_within_class",
    ]]
    if EXPORT_DEV:
        shap_df.to_csv(dev_importances / "SHAP_XGB_Feature_importances.csv", index=False)
    if EXPORT_RELEASE:
        shap_df.to_csv(release_imps / "SHAP_XGB_Feature_importances.csv", index=False)
else:
    print("  [INFO] No SHAP rows collected (or SHAP not installed).")

if not lr_contrib_rows:
    print("  [INFO] No LR contribution rows collected.")
else:
    # One row per class with each base learner's contribution to the LR meta-learner.
    lr_df = pd.DataFrame(lr_contrib_rows)
    if EXPORT_DEV:
        lr_df.to_csv(dev_importances / "LR_MetaLearner_BaseLearner_contributions.csv", index=False)
    if EXPORT_RELEASE:
        lr_df.to_csv(release_imps / "LR_MetaLearner_BaseLearner_contributions.csv", index=False)

# =============================================================================
# SECTION 6: SAVE PROBABILITIES
# =============================================================================
print("\n[STEP 6] Saving probability outputs...")

def _prob_frame(P, prefix):
    """Wrap an (n_test, K) probability matrix as a DataFrame with '<prefix>_<class>' columns."""
    return pd.DataFrame(P, index=test_index, columns=[f"{prefix}_{c}" for c in class_names])

if EXPORT_DEV:
    # Wide table: RAW, Platt and temperature-scaled (CAL) probabilities side by side,
    # followed by the true label and the RAW / CAL argmax predictions.
    probs_raw_df = _prob_frame(P_te_raw, "raw")
    probs_platt_df = _prob_frame(P_te_platt, "platt")
    probs_cal_df = _prob_frame(P_te_cal, "cal")
    probs_dev = pd.concat([probs_raw_df, probs_platt_df, probs_cal_df], axis=1)

    _raw_idx = P_te_raw.argmax(axis=1)
    _cal_idx = P_te_cal.argmax(axis=1)
    probs_dev["true_label"] = Triana_data_Test.loc[test_index, "Celltype"].values
    probs_dev["pred_raw"] = _raw_idx
    probs_dev["pred_cal"] = _cal_idx
    probs_dev["pred_raw_name"] = [class_names[i] for i in _raw_idx]
    probs_dev["pred_cal_name"] = [class_names[i] for i in _cal_idx]
    probs_dev.to_csv(probs_dir / "probabilities_before_after_TEST.csv", index=True)

if EXPORT_RELEASE:
    # Release table: CAL probabilities only, plus prediction and its confidence.
    probs_cal_df = _prob_frame(P_te_cal, "cal")
    probs_release = probs_cal_df.copy()
    _cal_idx = P_te_cal.argmax(axis=1)
    probs_release["true_label"] = Triana_data_Test.loc[test_index, "Celltype"].values
    probs_release["pred_cal"] = _cal_idx
    probs_release["pred_cal_name"] = [class_names[i] for i in _cal_idx]
    probs_release["max_cal_prob"] = probs_cal_df.max(axis=1).values
    probs_release.to_csv(release_probs / "Multiclass_models_probabilities_on_test.csv", index=True)

# =============================================================================
# SECTION 7: MULTICLASS EVALUATION (TEST) — using CAL probabilities
# =============================================================================
print("\n[STEP 7] Multiclass evaluation (TEST; using CAL probs)...\n")

y_pred_cal = P_te_cal.argmax(axis=1)

# Evaluate against the full trained label set 0..K-1 (as the confusion matrix does).
# Without `labels=`, classification_report infers labels from y_true ∪ y_pred and
# raises ValueError when a trained class is absent from both, because
# target_names (length K) no longer matches the inferred label list.
eval_labels = list(range(K))

report_txt = classification_report(
    y_test_multiclass, y_pred_cal, labels=eval_labels, target_names=class_names, digits=3
)
print("Multiclass Classification Report (TEST):")
print(report_txt)

cm_mc = confusion_matrix(y_test_multiclass, y_pred_cal, labels=eval_labels)
print("\nConfusion Matrix (rows=true, cols=pred):")
print(cm_mc)

report_df = pd.DataFrame(
    classification_report(
        y_test_multiclass, y_pred_cal, labels=eval_labels, target_names=class_names, output_dict=True
    )
).T

cm_mc_df = pd.DataFrame(cm_mc, index=pd.Index(class_names, name="true"), columns=pd.Index(class_names, name="pred"))

if EXPORT_DEV:
    report_df.to_csv(metrics_dir / "multiclass_classification_report_TEST.csv")
    cm_mc_df.to_csv(metrics_dir / "multiclass_confusion_matrix_TEST.csv")

if EXPORT_RELEASE:
    report_df.to_csv(release_metrics / "Multiclass_models_metrics_on_test.csv")
    cm_mc_df.to_csv(release_metrics / "Multiclass_models_confusion_matrix_on_test.csv")

# =============================================================================
# SECTION 8: FIGURES (MULTICLASS CM + PER-CLASS CONF & ROC)
# =============================================================================
print("\n[STEP 8] Saving plots...")

def _save_multiclass_cm_png(out_path: Path):
    """Render the multiclass TEST confusion matrix to a 300-dpi PNG at *out_path*.

    Reads the module-level ``cm_mc``, ``class_names`` and ``name_target_class``.
    The display is drawn onto an explicitly created 7x6 axes. Without ``ax=``,
    ``ConfusionMatrixDisplay.plot`` opens its own default-size figure: the sized
    figure stayed empty and the plotted figure was never closed (one leak per call).
    """
    fig, ax = plt.subplots(figsize=(7, 6))
    try:
        disp = ConfusionMatrixDisplay(confusion_matrix=cm_mc, display_labels=class_names)
        disp.plot(ax=ax, values_format="d", cmap="Blues", colorbar=False)
        ax.set_title(f"{name_target_class} – Multiclass Confusion Matrix (on TEST)")
        fig.tight_layout()
        fig.savefig(out_path, dpi=300)
    finally:
        plt.close(fig)

# Write the multiclass confusion-matrix image to every enabled export tree.
if EXPORT_DEV:
    _save_multiclass_cm_png(fig_root / "multiclass_confusion_matrix_TEST.png")
if EXPORT_RELEASE:
    _save_multiclass_cm_png(release_figs / "Multiclass_models_confusion_matrix_on_test.png")

# RAW argmax predictions and the accumulator for per-class (RAW vs CAL) metric rows.
y_pred_raw = P_te_raw.argmax(axis=1)
per_class_rows = []

def _metrics_from_cm(cm2x2):
    tn, fp, fn, tp = cm2x2.ravel()
    support = tp + fn
    prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    rec  = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1   = (2 * prec * rec) / (prec + rec) if (prec + rec) > 0 else 0.0
    return dict(TP=int(tp), FP=int(fp), TN=int(tn), FN=int(fn),
                support=int(support), precision=prec, recall=rec, f1=f1)

def _save_cm_fig(cm2x2, cls_label, title, out_dev: Path | None, out_rel: Path | None):
    """Save a one-vs-rest 2x2 confusion matrix (``Other`` vs *cls_label*) as PNG.

    Parameters
    ----------
    cm2x2 : array-like
        Binary confusion matrix with labels ordered ``[0 (Other), 1 (cls_label)]``.
    cls_label : str
        Display name of the positive class.
    title : str
        Axes title.
    out_dev, out_rel : Path | None
        Dev / Release destinations; ``None`` skips that destination.

    Notes
    -----
    The display is drawn onto an explicitly created 5.5x5.0 axes. Without
    ``ax=``, ``ConfusionMatrixDisplay.plot`` opens a separate default-size
    figure that was never closed (two leaked figures per class).
    """
    fig, ax = plt.subplots(figsize=(5.5, 5.0))
    try:
        ConfusionMatrixDisplay(confusion_matrix=cm2x2, display_labels=["Other", cls_label]).plot(
            ax=ax, values_format="d", cmap="Blues", colorbar=False
        )
        ax.set_title(title)
        fig.tight_layout()
        for out_path in (out_dev, out_rel):
            if out_path is not None:
                fig.savefig(out_path, dpi=300)
    finally:
        plt.close(fig)

def _save_roc(y_true, y_score, title, out_dev: Path | None, out_rel: Path | None):
    """Plot the ROC curve of *y_score* against binary *y_true*, save it, return the AUC.

    The title gets the AUC appended (3 decimals). The PNG (300 dpi) is written
    to each of *out_dev* / *out_rel* that is not ``None``.
    """
    fpr, tpr, _thresholds = roc_curve(y_true, y_score)
    roc_auc = auc(fpr, tpr)

    fig = plt.figure(figsize=(6.0, 5.5))
    plt.plot([0, 1], [0, 1], linestyle="--", linewidth=1, color="gray")  # chance diagonal
    plt.plot(fpr, tpr)
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"{title} AUC={roc_auc:.3f}")
    plt.tight_layout()
    for destination in (out_dev, out_rel):
        if destination is not None:
            plt.savefig(destination, dpi=300)
    plt.close(fig)
    return roc_auc

def _per_class_paths(filename):
    """Return ``(dev_path, release_path)`` for a per-class figure; ``None`` where that export is off."""
    dev_path = fig_percls / filename if EXPORT_DEV else None
    rel_path = release_single / filename if EXPORT_RELEASE else None
    return dev_path, rel_path


for k, cls in enumerate(class_names):
    cls_safe = MLTraining.safe_name(cls)
    # One-vs-rest view of the multiclass problem for class k.
    y_true_bin = (y_test_multiclass == k).astype(int)
    cm_raw = confusion_matrix(y_true_bin, (y_pred_raw == k).astype(int), labels=[0, 1])
    cm_cal = confusion_matrix(y_true_bin, (y_pred_cal == k).astype(int), labels=[0, 1])

    _title_prefix = f"{name_target_class} – {cls}"
    _save_cm_fig(
        cm_raw, cls, f"{_title_prefix}: Confusion Matrix (RAW; pre-Platt & Temp)",
        *_per_class_paths(f"{cls_safe}_RAW_confusion_matrix_on_test.png"),
    )
    _save_cm_fig(
        cm_cal, cls, f"{_title_prefix}: Confusion Matrix (CAL; Platt & Temp)",
        *_per_class_paths(f"{cls_safe}_CAL_confusion_matrix_on_test.png"),
    )
    auc_raw = _save_roc(
        y_true_bin, P_te_raw[:, k], f"{_title_prefix}: ROC (RAW; pre-Platt & Temp)",
        *_per_class_paths(f"{cls_safe}_RAW_ROC_on_test.png"),
    )
    auc_cal = _save_roc(
        y_true_bin, P_te_cal[:, k], f"{_title_prefix}: ROC (CAL; Platt & Temp)",
        *_per_class_paths(f"{cls_safe}_CAL_ROC_on_test.png"),
    )

    # One metrics row per model variant (RAW first, then CAL).
    for model_tag, class_cm, class_auc in (("RAW", cm_raw, auc_raw), ("CAL", cm_cal, auc_cal)):
        row = _metrics_from_cm(class_cm)
        row.update(model=model_tag, class_name=cls, auc=class_auc)
        per_class_rows.append(row)

# =============================================================================
# SECTION 9: SAVE METRICS TABLES
# =============================================================================
print("\n[STEP 9] Saving metrics tables...")

# One row per (class, model) pair (RAW vs CAL) in a fixed column order.
_per_class_columns = [
    "class_name", "model", "TP", "FP", "TN", "FN",
    "support", "precision", "recall", "f1", "auc",
]
per_class_df = pd.DataFrame(per_class_rows)
per_class_df = per_class_df[_per_class_columns].sort_values(["class_name", "model"])

if EXPORT_DEV:
    per_class_df.to_csv(metrics_dir / "per_class_argmax_metrics_TEST_included.csv", index=False)

if EXPORT_RELEASE:
    out_single = release_metrics / "Single_classes_metrics_and_confusion_matrix_on_test.csv"
    per_class_df.to_csv(out_single, index=False)

if EXPORT_DEV:
    # Per-head metrics gathered during training; the helper name suggests it
    # appends to an existing CSV (TODO confirm in MLTraining.append_metrics_csv).
    metrics_df = pd.DataFrame.from_records(metrics_log)
    MLTraining.append_metrics_csv(metrics_df, csv_path=dev_root / "stacker_metrics.csv")

print("\n✅ DETAILED PIPELINE COMPLETE. Exports saved according to EXPORT_DEV / EXPORT_RELEASE.\n")


## Luecken Models

In [None]:
# Create the Luecken output tree: the root, Dev/ and Release/, each with Models/.
for _luecken_subdir in ("", "/Dev", "/Release", "/Dev/Models", "/Release/Models"):
    os.makedirs(data_path + "/Luecken" + _luecken_subdir, exist_ok=True)

models_output = data_path + "/Luecken"

### ML Training

In [None]:
Luecken_Models = dict()  # container for the Luecken models, initially empty

#### Broad annotation

In [None]:
# -*- coding: utf-8 -*-
# =============================================================================
# MODEL TRAINING PIPELINE (LEAN MAIN SCRIPT)
#   - RAW vs PLATT vs TEMP-SCALED
#   - DEV/RELEASE exports
#   - Importances: XGB SHAP mean_abs + corr (Top10) + LR meta-learner contributions
#   - Platt calibration plots (Ideal -> RAW -> Platt on top) with TEST LogLoss/Brier in legend
#   - Per-class pre/post Platt metrics exported to CSV
#   - Per-class TRAIN UMAP (pos vs rest) + legend PNG
#
# NOTE:
#   Many helper functions are now provided by MLTraining.py.
#   This script should primarily orchestrate: data prep -> model training loop -> exports.
# =============================================================================

from pathlib import Path
import joblib
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import StackingClassifier
from sklearn.calibration import CalibratedClassifierCV
from sklearn.metrics import (
    classification_report, confusion_matrix, ConfusionMatrixDisplay,
    roc_curve, auc,
)

import MLTraining  # uses MLTraining.py helpers

# -----------------------------------------------------------------------------
# Palettes
# -----------------------------------------------------------------------------
# Label -> hex colour maps, one per annotation depth ("Broad", "Simplified",
# "Detailed"). Keys are presumably the class labels of the matching
# Consensus_annotation_<depth>_final field — TODO confirm they stay in sync.

# Depth "Broad": two classes.
PALETTE_BROAD = {
    'Immature': "#0079ea", 
    'Mature': "#AF3434"
}

# Depth "Simplified".
PALETTE_SIMPLIFIED = {
    "HSPC":      "#0079ea",
    "Erythroid": "#c11212",
    "pDC":       "#62E6B8",
    "Monocyte":  "#D27CE3",
    "Myeloid":   "#8D43CD",
    "CD4_T":     "#C99546",
    "CD8_T":     "#6B3317",
    "B":         "#68D827",
    "cDC":       "#16D2E3",
    "Other_T":   "#EDB416",
    "NK":        "#FBEF0D",
}

# Depth "Detailed".
PALETTE_DETAILED = {
    'HSC_MPP':            '#0079ea',
    'LMPP':               "#17BECF",
    'GMP':                "#C5E4FF",
    'Myeloid progenitor': "#AEC7E8",
    # NOTE(review): 'Monocyte' and 'CD14 Mono' share "#D27CE3" — confirm this is intended.
    'Monocyte':           "#D27CE3",
    'CD14 Mono':         "#D27CE3",
    'CD16 Mono':         "#8D43CD",
    'Erythroblast':      "#F30A1A",
    'ErP':               "#D1235A",
    'MEP':               "#E364B0",
    'CD4 T Naive':       "#C99546",
    'CD4 T Memory':      "#C1AF93",
    'CD8 T Naive':       "#4D382E",
    'CD8 T Memory':      "#6B3317",
    'Other_T':           "#EDB416",
    'Treg':              "#6E6C37",
    'B Naive':          '#1C511D',
    'B Memory':         "#68D827",
    'Pro-B':            "#66BB6A",
    'Pre-B':            "#2DBD67",
    'Immature B':      "#91FF7B",
    'Plasma':           "#9DC012",
    'cDC1':             "#76A7CB",
    'cDC2':             "#16D2E3",
    'pDC':              "#69FFCB",
    'NK CD56 bright':  "#F3AC1F",
    'NK CD56 dim':     "#FBEF0D",
}

# -----------------------------------------------------------------------------
# OPTIONAL: SHAP dependency
# -----------------------------------------------------------------------------
# HAS_SHAP gates the SHAP importance/beeswarm step of the training loop. Any
# failure while importing shap (not only ImportError) just disables that step.
HAS_SHAP = False
try:
    import shap  # noqa: F401
except Exception:
    pass
else:
    HAS_SHAP = True


# -----------------------------------------------------------------------------
# EXPORT SWITCHES
# -----------------------------------------------------------------------------
EXPORT_RELEASE = True    # set False to disable Release outputs
EXPORT_DEV     = False   # set True to enable Dev outputs


# -----------------------------------------------------------------------------
# CONFIG
# -----------------------------------------------------------------------------
name_target_class = "Broad"  # "Broad" | "Simplified" | "Detailed"
kf          = MLTraining.CV  # CV splitter shared by base learners and the stacker
num_cores   = -1             # presumably "all cores" (joblib convention) — confirm in MLTraining
metrics_log = []             # per-head TEST metrics appended by the training loop

# -----------------------------------------------------------------------------
# EMBEDDING CONFIG (for Class_Train_data.png)
# -----------------------------------------------------------------------------
# Choose where to read the 2D embedding from.
# Supported:
#   - "adata_obsm": read from adata_train.obsm[obsm_key]
#   - "adata_obs":  read from adata_train.obs[[obs_x, obs_y]]
#   - "train_df":   read from train_df[[df_x, df_y]] (e.g., Luecken_data_Train has UMAP columns)
EMBEDDING_SOURCE = "adata_obsm"   # "adata_obsm" | "adata_obs" | "train_df"

# If EMBEDDING_SOURCE == "adata_obsm"
EMBEDDING_OBSM_KEY = "X_umap"     # e.g. "X_umap", "X_pca"

# If EMBEDDING_SOURCE == "adata_obs"
EMBEDDING_OBS_X = "UMAP_1"
EMBEDDING_OBS_Y = "UMAP_2"

# If EMBEDDING_SOURCE == "train_df"
EMBEDDING_DF_X = "UMAP_1"
EMBEDDING_DF_Y = "UMAP_2"

# Fail fast on typos. A bad depth would otherwise surface later as a missing
# obs field / barcode CSV. A bad embedding source only fails inside the
# best-effort UMAP try/except, and its warnings.warn output is silenced by the
# notebook-wide warnings.filterwarnings('ignore').
_SUPPORTED_DEPTHS = ("Broad", "Simplified", "Detailed")
_SUPPORTED_EMBEDDING_SOURCES = ("adata_obsm", "adata_obs", "train_df")
if name_target_class not in _SUPPORTED_DEPTHS:
    raise ValueError(f"name_target_class must be one of {_SUPPORTED_DEPTHS}, got {name_target_class!r}")
if EMBEDDING_SOURCE not in _SUPPORTED_EMBEDDING_SOURCES:
    raise ValueError(f"EMBEDDING_SOURCE must be one of {_SUPPORTED_EMBEDDING_SOURCES}, got {EMBEDDING_SOURCE!r}")


# -----------------------------------------------------------------------------
# ROOTS
# -----------------------------------------------------------------------------
Luecken_root = Path(models_output)

# Dev tree: <root>/Dev/<depth>/{Models,Reports,Figures}/<depth>/...
dev_root = Luecken_root / "Dev"
_dev_depth_root = dev_root / name_target_class
models_root = _dev_depth_root / "Models" / name_target_class
reports_root = _dev_depth_root / "Reports" / name_target_class
fig_root = _dev_depth_root / "Figures" / name_target_class

heads_dir = models_root / "heads"
metrics_dir = reports_root / "metrics"
probs_dir = reports_root / "probabilities"
fig_percls = fig_root / "per_class"
dev_importances = reports_root / "Importances"

# Release tree: <root>/Release/<depth>/{Models, Reports/..., Figures/...}
release_root = Luecken_root / "Release"
_release_depth_root = release_root / name_target_class
release_models = _release_depth_root / "Models"
release_reports = _release_depth_root / "Reports"
release_metrics = release_reports / "Metrics"
release_probs = release_reports / "Probabilities"
release_imps = release_reports / "Importances"
release_figs = _release_depth_root / "Figures"
release_single = release_figs / "Single_classes"


def _ensure_dirs(directories):
    """Create each directory (with parents); existing directories are left untouched."""
    for directory in directories:
        directory.mkdir(parents=True, exist_ok=True)


# Only materialise the trees whose export switch is on.
if EXPORT_DEV:
    _ensure_dirs((models_root, heads_dir, reports_root, metrics_dir, probs_dir, fig_root, fig_percls, dev_importances))
    print(f"[INFO] DEV Models:  {models_root}")
    print(f"[INFO] DEV Reports: {reports_root}")
    print(f"[INFO] DEV Figures: {fig_root}")

if EXPORT_RELEASE:
    _ensure_dirs((release_models, release_reports, release_metrics, release_probs, release_imps, release_figs, release_single))
    print(f"[INFO] RELEASE Root:    {release_root}")
    print(f"[INFO] RELEASE Models:  {release_models}")
    print(f"[INFO] RELEASE Reports: {release_reports}")
    print(f"[INFO] RELEASE Figures: {release_figs}")


# =============================================================================
# SECTION 1: ATTACH CELL-TYPE LABELS
# =============================================================================
print("\n[STEP 1] Attaching cell-type labels from AnnData.obs...")

consensus_field = f"Consensus_annotation_{name_target_class.lower()}_final"

# Pair each feature table with the AnnData split it came from (Train, Test, Cal).
Luecken_data_Train, Luecken_data_Test, Luecken_data_Cal = [
    MLTraining.attach_celltype(split_df, split_adata, consensus_field)
    for split_df, split_adata in (
        (Luecken_data_Train, Luecken_dataset_Train),
        (Luecken_data_Test, Luecken_dataset_Test),
        (Luecken_data_Cal, Luecken_dataset_Cal),
    )
]

print(f"  ✓ Attached '{consensus_field}' to Train/Test/Cal splits")


# =============================================================================
# SECTION 2: ALIGN DATA COLUMNS TO REFERENCE PANEL
# =============================================================================
print("\n[STEP 2] Aligning data columns to reference panel (exact names preserved)...")

panel = pd.Index([str(marker) for marker in TotalSeqD_Heme_Oncology_CAT399906])
panel_keys = MLTraining.norm_feats(panel)
norm_to_panel = {key: panel_name for key, panel_name in zip(panel_keys, panel)}
# Two panel names normalising to the same key would make the reverse mapping ambiguous.
if len(norm_to_panel) != len(panel):
    raise ValueError("Panel contains names that collide after normalization. Adjust MLTraining.norm_feats rules.")

def rename_data_to_panel(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df* with feature columns renamed to their exact panel spelling.

    Feature columns (everything except ``cell_barcode`` / ``Celltype``) are matched
    through ``MLTraining.norm_feats`` keys looked up in the module-level
    ``norm_to_panel``. If several columns map to the same panel name, the first
    keeps it and the later ones are dropped (with a warning print). Columns
    without a panel match keep their original name.
    """
    out = df.copy()
    meta_cols = [c for c in ["cell_barcode", "Celltype"] if c in out.columns]
    feature_cols = pd.Index([c for c in out.columns if c not in meta_cols])

    panel_names = (norm_to_panel.get(key) for key in MLTraining.norm_feats(feature_cols))
    candidates = {old: new for old, new in zip(feature_cols, panel_names) if new is not None}

    claimed = set()
    renames, duplicates = {}, []
    for old, new in candidates.items():
        if new in claimed:
            duplicates.append(old)
            continue
        claimed.add(new)
        renames[old] = new

    if duplicates:
        print(f"  [WARN] Dropping {len(duplicates)} duplicated-mapped columns (sample: {duplicates[:5]})")
        out = out.drop(columns=duplicates, errors="ignore")

    out = out.rename(columns=renames)
    print(f"  ✓ Matched {len(renames)}/{len(feature_cols)} data columns to panel")
    return out

def panel_intersection(df: pd.DataFrame) -> pd.DataFrame:
    """Restrict *df* to feature columns present in the module-level ``panel``.

    Kept features follow panel order and are followed by whichever of
    ``cell_barcode`` / ``Celltype`` exist. Raises ValueError if no feature
    column matches the panel.
    """
    meta_cols = [c for c in ("cell_barcode", "Celltype") if c in df.columns]
    features = pd.Index([c for c in df.columns if c not in meta_cols])
    shared = panel.intersection(features, sort=False)
    if shared.empty:
        raise ValueError("Panel/Data intersection is empty after renaming. Check mapping rules.")
    return df.reindex(columns=[*shared, *meta_cols])

# Rename to panel spelling, then keep only panel ∩ data features (Train, Test, Cal).
Luecken_data_Train, Luecken_data_Test, Luecken_data_Cal = (
    panel_intersection(rename_data_to_panel(split_df))
    for split_df in (Luecken_data_Train, Luecken_data_Test, Luecken_data_Cal)
)

print("  ✓ Data columns now aligned to panel (panel order preserved)")


# =============================================================================
# SECTION 3: PREPARE FEATURES & LABELS
# =============================================================================
print("\n[STEP 3] Extracting features and labels...")

Luecken_data_Cal_lbl = Luecken_data_Cal[["Celltype"]].copy()

_META_COLS = ("cell_barcode", "Celltype")


def _features_only(df: pd.DataFrame) -> pd.DataFrame:
    """Drop whichever metadata columns (barcode / label) are present, keeping features only."""
    return df.drop(columns=[c for c in _META_COLS if c in df.columns], errors="ignore")


Luecken_data_Train_Sub = _features_only(Luecken_data_Train)
Luecken_data_Test_Sub = _features_only(Luecken_data_Test)
Luecken_data_Cal_Sub = _features_only(Luecken_data_Cal)

# All three splits must expose the identical, identically ordered feature list.
cols_train = list(Luecken_data_Train_Sub.columns)
if any(list(sub.columns) != cols_train for sub in (Luecken_data_Test_Sub, Luecken_data_Cal_Sub)):
    raise ValueError("Train/Cal/Test feature columns differ after panel intersection!")

for _split_name, _split_df in (
    ("TRAIN", Luecken_data_Train_Sub),
    ("TEST", Luecken_data_Test_Sub),
    ("CAL", Luecken_data_Cal_Sub),
):
    MLTraining.check_finite(_split_df, _split_name)

print(f"  ✓ Using {len(cols_train)} panel-intersected features (exact panel names)")
print(f"    Sample: {cols_train[:5]}...")

# Class index = position in the sorted list of TRAIN labels.
class_names = sorted(pd.Series(Luecken_data_Train["Celltype"]).dropna().unique())
K = len(class_names)
class_to_idx = {label: idx for idx, label in enumerate(class_names)}
print(f"  ✓ Found {K} classes")


def _encode_labels(labels: pd.Series, split: str) -> np.ndarray:
    """Map labels to TRAIN class indices (int64); raise if a label never occurs in TRAIN."""
    encoded = labels.map(class_to_idx)
    unknown = encoded.isna()
    if unknown.any():
        raise ValueError(f"Unknown labels in {split} split: {labels[unknown].unique()}")
    return encoded.to_numpy(dtype=np.int64)


y_cal_multiclass = _encode_labels(Luecken_data_Cal_lbl["Celltype"], "CAL")
y_test_multiclass = _encode_labels(Luecken_data_Test["Celltype"], "TEST")

X_cal_all_df = Luecken_data_Cal_Sub.copy()
X_te_all_df = Luecken_data_Test_Sub.copy()
test_index = Luecken_data_Test_Sub.index

# OvR probability matrices (rows = cells, columns = classes), filled per head in STEP 4.
_n_cal, _n_te = X_cal_all_df.shape[0], X_te_all_df.shape[0]
P_cal_raw = np.zeros((_n_cal, K), dtype=float)
P_cal_platt = np.zeros((_n_cal, K), dtype=float)
P_te_raw = np.zeros((_n_te, K), dtype=float)
P_te_platt = np.zeros((_n_te, K), dtype=float)

heads_mem = {}

# Importance / calibration collectors
xgb_shap_rows = []       # SHAP mean_abs + corr rows (top-N per class selected at export)
lr_contrib_rows = []     # LR meta-learner base-learner contributions (from stacker_raw)
platt_metrics_rows = []  # per-class logloss/brier, pre vs post Platt


# =============================================================================
# SECTION 4: TRAIN OvR BINARY HEADS (+ Platt on CAL)
# =============================================================================
print(f"\n[STEP 4] Training {K} binary OvR classifiers...\n")

TOP_N = 10
base_order = ["NB", "XGB", "KNN", "MLP"]

# UMAP palette matching the depth being trained. PALETTE_BROAD used to be
# hard-coded, which is only correct for name_target_class == "Broad".
# Unknown depths keep the previous behaviour (Broad palette).
_DEPTH_PALETTES = {
    "Broad": PALETTE_BROAD,
    "Simplified": PALETTE_SIMPLIFIED,
    "Detailed": PALETTE_DETAILED,
}
class_palette = _DEPTH_PALETTES.get(name_target_class, PALETTE_BROAD)


def _report_failure(msg: str) -> None:
    """Make a best-effort failure visible.

    The notebook calls ``warnings.filterwarnings('ignore')`` at start-up, which
    turns a bare ``warnings.warn`` into a no-op, so the message is printed too.
    """
    print(f"    [WARN] {msg}")
    warnings.warn(msg)


for celltype in class_names:
    k = class_to_idx[celltype]
    cls_safe = MLTraining.safe_name(celltype)
    print(f"▸ Processing {cls_safe} (class {k+1}/{K})")

    # 4.1 Load TRAIN barcodes for this class (Positive / Negative columns)
    train_barcodes_df = pd.read_csv(
        f"{train_barcodes_path}/Luecken/Consensus_annotation_{name_target_class.lower()}_final/Barcodes_training_class_{cls_safe}.csv",
        index_col=0
    )
    train_positive_barcodes = train_barcodes_df["Positive"].dropna().values
    train_negative_barcodes = train_barcodes_df["Negative"].dropna().values
    all_train_barcodes = np.concatenate([train_positive_barcodes, train_negative_barcodes])

    train_mask = Luecken_data_Train_Sub.index.isin(all_train_barcodes)
    X_tr_df = Luecken_data_Train_Sub.loc[train_mask]
    found_train_barcodes = X_tr_df.index.values
    y_tr = np.isin(found_train_barcodes, train_positive_barcodes).astype(int)

    if X_tr_df.empty or np.unique(y_tr).size < 2:
        print(f"  [SKIP] Empty or single-class train (pos={y_tr.sum()}, neg={len(y_tr)-y_tr.sum()})\n")
        continue

    # 4.1b TRAIN UMAP (pos vs rest) + legend — best effort
    try:
        MLTraining.save_class_train_umap_pngs(
            celltype=str(celltype),
            cls_safe=cls_safe,
            barcodes=found_train_barcodes,
            y_bin=y_tr,
            custom_palette=class_palette,
            out_dir_dev=fig_percls if EXPORT_DEV else None,
            out_dir_rel=release_single if EXPORT_RELEASE else None,
            adata_train=Luecken_dataset_Train,
            train_df=Luecken_data_Train,
            embedding_source=EMBEDDING_SOURCE,
            obsm_key=EMBEDDING_OBSM_KEY,
            obs_x=EMBEDDING_OBS_X,
            obs_y=EMBEDDING_OBS_Y,
            df_x=EMBEDDING_DF_X,
            df_y=EMBEDDING_DF_Y,
            neg_color="#A3A3A3",
            outline=(5, 0.05),
            debug=(str(celltype) == "Mature"),
        )
    except Exception as e:
        _report_failure(f"UMAP train plot failed for '{celltype}': {e}")

    # 4.2 Load TEST barcodes for class-specific metrics (the CSV must exist)
    test_barcodes_df = pd.read_csv(
        f"{test_barcodes_path}/Luecken/Consensus_annotation_{name_target_class.lower()}_final/Barcodes_testing_class_{cls_safe}.csv",
        index_col=0
    )
    test_positive_barcodes = test_barcodes_df["Positive"].dropna().values
    test_negative_barcodes = test_barcodes_df["Negative"].dropna().values
    all_test_barcodes = np.concatenate([test_positive_barcodes, test_negative_barcodes])

    test_mask = Luecken_data_Test_Sub.index.isin(all_test_barcodes)
    X_te_df = Luecken_data_Test_Sub.loc[test_mask]
    found_test_barcodes = X_te_df.index.values
    y_te = np.isin(found_test_barcodes, test_positive_barcodes).astype(int)

    # Full TEST for head probabilities / calibration plot eval
    X_te_all_local = X_te_all_df
    y_te_all = (Luecken_data_Test["Celltype"].values == celltype).astype(int)

    # CAL split for Platt fitting
    X_cal_df  = X_cal_all_df
    y_cal_bin = (Luecken_data_Cal_lbl["Celltype"].values == celltype).astype(int)

    # 4.3 Fit scaler on TRAIN; transform all splits
    scaler = StandardScaler(with_mean=True, with_std=True).fit(X_tr_df.values)

    def _sc(df: pd.DataFrame) -> pd.DataFrame:
        """Standardise *df* with this head's TRAIN-fitted scaler, keeping index and feature names."""
        return pd.DataFrame(
            scaler.transform(df.values),
            index=df.index,
            columns=cols_train
        )

    X_tr_sc_df      = _sc(X_tr_df)
    # No TEST barcode may match for this class; StandardScaler rejects 0-sample
    # input, which previously aborted the whole loop. Keep None and skip 4.9.
    X_te_sc_df      = _sc(X_te_df) if not X_te_df.empty else None
    X_te_all_sc_df  = _sc(X_te_all_local)
    X_cal_sc_df     = _sc(X_cal_df)

    # 4.4 Train base learners
    NB_model  = MLTraining.train_NB (X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)
    XGB_model = MLTraining.train_XGB(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)
    KNN_model = MLTraining.train_KNN(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)
    MLP_model = MLTraining.train_MLP(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)

    # 4.5 Stacking RAW head
    stacker_raw = StackingClassifier(
        estimators=[("NB", NB_model), ("XGB", XGB_model), ("KNN", KNN_model), ("MLP", MLP_model)],
        final_estimator=LogisticRegression(max_iter=2000, class_weight="balanced", random_state=42),
        stack_method="predict_proba",
        cv=kf,
        n_jobs=-1,
    ).fit(X_tr_sc_df, y_tr)

    # 4.6 Platt calibration (fit on CAL only; needs both classes present in CAL)
    pos_cal   = int(y_cal_bin.sum())
    n_cal_bin = int(len(y_cal_bin))
    has_both  = (0 < pos_cal < n_cal_bin)

    stacker_platt = None
    if has_both:
        stacker_platt = MLTraining.calibrate_prefit(stacker_raw, X_cal_sc_df, y_cal_bin, method="sigmoid")
    else:
        print("    [WARN] Skipped Platt calibration (single-class CAL)")

    # Full-TEST positive-class scores, computed once and reused for the
    # calibration plot (4.7) and the OvR probability matrix (4.10).
    p_test_raw   = stacker_raw.predict_proba(X_te_all_sc_df)[:, 1]
    p_test_platt = stacker_platt.predict_proba(X_te_all_sc_df)[:, 1] if stacker_platt is not None else None

    # 4.7 Platt evaluation curve on TEST (Ideal -> RAW -> Platt) + metrics row
    try:
        dev_platt = (fig_percls / f"{cls_safe}_Platt_calibration_evaluation_on_test.png") if EXPORT_DEV else None
        rel_platt = (release_single / f"{cls_safe}_Platt_calibration_evaluation_on_test.png") if EXPORT_RELEASE else None

        ll_raw, br_raw, ll_pl, br_pl, pl_avail = MLTraining.plot_platt_calibration_on_test(
            y_true_bin=y_te_all.astype(int),
            p_raw=p_test_raw,
            p_platt=p_test_platt,
            title=f"{name_target_class} – {celltype}: Platt calibration evaluation on TEST",
            out_png_dev=dev_platt,
            out_png_rel=rel_platt,
            n_bins=15,
        )

        platt_metrics_rows.append({
            "depth": name_target_class,
            "class_name": str(celltype),
            "n_test_samples": int(len(y_te_all)),
            "n_test_positive": int(y_te_all.sum()),
            "logloss_raw": ll_raw,
            "brier_raw": br_raw,
            "logloss_platt": ll_pl,
            "brier_platt": br_pl,
            "platt_available": bool(pl_avail),
        })

    except Exception as e:
        _report_failure(f"Platt calibration plot failed for class '{celltype}': {e}")

    # 4.8 Save per-class head bundle + keep in-memory for package
    head_bundle = {
        "atlas": "Luecken",
        "depth": name_target_class,
        "label": str(celltype),
        "panel_name": "TotalSeqD_Heme_Oncology_CAT399906",
        "columns": cols_train,
        "scaler": scaler,
        "model_raw": stacker_raw,
        "model_platt": stacker_platt,
    }
    heads_mem[str(celltype)] = head_bundle

    if EXPORT_DEV:
        joblib.dump(head_bundle, heads_dir / f"{cls_safe}.joblib")

    # 4.9 Optional per-head metrics logging (class-specific TEST subset)
    if X_te_sc_df is None:
        _report_failure(f"No class-specific TEST barcodes found for '{celltype}'; per-head metrics skipped.")
    else:
        try:
            model_for_eval = stacker_platt if stacker_platt is not None else stacker_raw
            m = MLTraining.evaluate_classifier(model_for_eval, X_te_sc_df, y_te, plot_cm=False)
            m.update(celltype=str(celltype), used_platt=bool(stacker_platt is not None))
            metrics_log.append(m)
        except Exception as e:
            # Still optional, but say why the row is missing instead of hiding it.
            _report_failure(f"Per-head TEST metrics failed for class '{celltype}': {e}")

    # 4.10 OvR probability matrices (RAW + PLATT) for multiclass downstream.
    # Without Platt, the RAW column is reused so the PLATT matrix stays complete.
    P_cal_raw[:, k] = stacker_raw.predict_proba(X_cal_sc_df)[:, 1]
    P_te_raw[:,  k] = p_test_raw

    if stacker_platt is not None:
        P_cal_platt[:, k] = stacker_platt.predict_proba(X_cal_sc_df)[:, 1]
        P_te_platt[:,  k] = p_test_platt
    else:
        P_cal_platt[:, k] = P_cal_raw[:, k]
        P_te_platt[:,  k] = P_te_raw[:,  k]

    # 4.11 SHAP: mean_abs + corr on TEST; beeswarm TRAIN only
    if HAS_SHAP:
        try:
            shap_sum_test = MLTraining.xgb_shap_mean_abs_and_corr(XGB_model, X_te_all_sc_df, class_index=1)
            shap_sum_test["depth"] = name_target_class
            shap_sum_test["class_name"] = str(celltype)
            shap_sum_test["dataset"] = "TEST"
            xgb_shap_rows.extend(shap_sum_test.to_dict(orient="records"))

            # Beeswarm on TRAIN only: SHAP values are computed once and drawn to
            # each enabled export tree (previously recomputed per tree).
            beeswarm_paths = []
            if EXPORT_DEV:
                beeswarm_paths.append(fig_percls / f"{cls_safe}_SHAP_beeswarm_TRAIN.png")
            if EXPORT_RELEASE:
                beeswarm_paths.append(release_single / f"{cls_safe}_SHAP_beeswarm_TRAIN.png")
            if beeswarm_paths:
                sv, _ = MLTraining.xgb_shap_values(XGB_model, X_tr_sc_df, class_index=1)
                for outp in beeswarm_paths:
                    MLTraining.plot_xgb_shap_beeswarm(
                        sv, X_tr_sc_df,
                        title=f"{name_target_class} – {celltype}: SHAP importance",
                        max_display=TOP_N,
                        out_path=outp,
                        figsize=(6, 7),
                    )

        except Exception as e:
            _report_failure(f"SHAP failed for class '{celltype}': {e}")

    # 4.12 LR meta-learner contributions (helper defined elsewhere in the notebook)
    try:
        contrib = _lr_baselearner_contributions(stacker_raw, X_te_all_sc_df, base_order=base_order)
        row = {
            "depth": name_target_class,
            "class_name": str(celltype),
            "dataset": "TEST",
            "n_meta_features": contrib["n_meta_features"],
            "per_estimator_meta_cols": contrib["per_estimator_meta_cols"],
        }
        for b in base_order:
            row[f"{b}_mean_abs_contribution"] = contrib["per_base"].get(b, {}).get("mean_abs_contribution", 0.0)
            row[f"{b}_coef_l1"]               = contrib["per_base"].get(b, {}).get("coef_l1", 0.0)
            row[f"{b}_n_meta_cols"]           = contrib["per_base"].get(b, {}).get("n_cols", 0)
        lr_contrib_rows.append(row)
    except Exception as e:
        _report_failure(f"LR contribution extraction failed for class '{celltype}': {e}")

    print("")


# =============================================================================
# EXPORT: Per-class LogLoss & Brier (pre vs post Platt) on TEST
# =============================================================================
print("\n[EXPORT] Per-class calibration metrics (RAW vs Platt on TEST)...")

# Destinations for the per-class LogLoss/Brier table. A disabled export tree is
# passed as None (presumably "skip" inside MLTraining.export_platt_metrics_csv).
_platt_dev_dir = metrics_dir if EXPORT_DEV else None
_platt_rel_dir = release_metrics if EXPORT_RELEASE else None
_ = MLTraining.export_platt_metrics_csv(
    platt_metrics_rows,
    out_dev=_platt_dev_dir,
    out_rel=_platt_rel_dir,
    filename="Single_classes_metrics_pre_and_post_platt_calibration.csv",
)


# =============================================================================
# SECTION 5: MULTICLASS TEMPERATURE SCALING (fit on CAL using PLATT matrix)
# =============================================================================
print("\n[STEP 5] Multiclass Temperature Scaling on CAL (using Platt OvR probabilities)...")

def _check_probs(P: np.ndarray, name: str):
    if np.isnan(P).any() or np.isinf(P).any():
        raise ValueError(f"{name} contains NaN/Inf")
    if (P < 0).any() or (P > 1).any():
        raise ValueError(f"{name} contains values outside [0,1]")

# Sanity-check both Platt OvR matrices before they feed the temperature scaler.
_check_probs(P_cal_platt, "P_cal_platt")
_check_probs(P_te_platt,  "P_te_platt")

# Fit one multiclass temperature on CAL, then apply it to TEST.
ts_cal = TemperatureScaling()
ts_cal.fit(P_cal_platt, y_cal_multiclass)
P_te_ts = np.asarray(ts_cal.transform(P_te_platt))
# Remember the shape TemperatureScaling actually produced; P_te_cal may be
# replaced by the fallback below, so its shape can no longer be used to report it.
ts_out_shape = P_te_ts.shape

P_te_cal = P_te_ts.reshape(-1, 1) if P_te_ts.ndim == 1 else P_te_ts
if P_te_cal.shape[1] == 1 and K == 2:
    # Single-column binary output: rebuild the two-column [P(0), P(1)] layout.
    P_te_cal = np.hstack([1.0 - P_te_cal, P_te_cal])
elif P_te_cal.shape[1] != K:
    # Unexpected width: fall back to row-normalising the Platt OvR scores.
    row_sums = P_te_platt.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0.0] = 1.0  # all-zero rows stay zero instead of dividing by 0
    P_te_cal = P_te_platt / row_sums
    # NOTE(review): the saved temp_scaler is still ts_cal even when this fallback
    # is used for the TEST outputs — confirm that is intended for deployment.
    print(f"  [WARN] TemperatureScaling returned shape {ts_out_shape}; fell back to sum-normalized OvR probs")

if EXPORT_DEV:
    joblib.dump(ts_cal, models_root / "temp_scaler.joblib")
    pd.Series(class_names, name="class_name").to_csv(models_root / "class_names.csv", index=False)


# =============================================================================
# SECTION 5b: SAVE DEPLOYABLE PACKAGE(S)
# =============================================================================
print("\n[STEP 5b] Saving deployable package(s)...")

# Single bundle for inference: metadata, the class order defining the
# probability columns, the in-memory heads and the fitted temperature scaler.
package = dict(
    atlas="Luecken",
    depth=name_target_class,
    panel_name="TotalSeqD_Heme_Oncology_CAT399906",
    class_names=class_names,
    heads=heads_mem,
    temp_scaler=ts_cal,
)

# The same object goes to every enabled destination, under different file names.
if EXPORT_DEV:
    joblib.dump(package, models_root / "package.joblib")
if EXPORT_RELEASE:
    joblib.dump(package, release_models / "Multiclass_models.joblib")


# =============================================================================
# SECTION 5c: EXPORT IMPORTANCES (Top10 per class)
# =============================================================================
print("\n[STEP 5c] Exporting importances (Top 10 per class; SHAP mean_abs + corr + LR)...")

# --- XGB SHAP: TOP_N features per (depth, class_name), by descending mean |SHAP| ---
shap_df = None
if len(xgb_shap_rows) == 0:
    print("  [INFO] No SHAP rows collected (or SHAP not installed).")
else:
    shap_group = ["depth", "class_name"]
    shap_sorted = pd.DataFrame(xgb_shap_rows).sort_values(
        shap_group + ["mean_abs_shap"], ascending=[True, True, False]
    )
    shap_df = shap_sorted.groupby(shap_group, as_index=False).head(TOP_N)

    # 1-based rank inside each class; ties are ranked in their current row order.
    shap_df["rank_within_class"] = (
        shap_df.groupby(shap_group)["mean_abs_shap"]
        .rank(ascending=False, method="first")
        .astype(int)
    )

    shap_df = shap_df[[
        "depth", "class_name", "dataset",
        "feature", "mean_abs_shap", "corr_feature_value_vs_shap",
        "rank_within_class",
    ]]

    if EXPORT_DEV:
        shap_df.to_csv(dev_importances / "SHAP_XGB_Feature_importances.csv", index=False)
    if EXPORT_RELEASE:
        shap_df.to_csv(release_imps / "SHAP_XGB_Feature_importances.csv", index=False)

# --- LR meta-learner: per-class contribution of each base learner ---
lr_df = None
if len(lr_contrib_rows) == 0:
    print("  [INFO] No LR contribution rows collected.")
else:
    lr_df = pd.DataFrame(lr_contrib_rows)
    lr_csv_name = "LR_MetaLearner_BaseLearner_contributions.csv"
    if EXPORT_DEV:
        lr_df.to_csv(dev_importances / lr_csv_name, index=False)
    if EXPORT_RELEASE:
        lr_df.to_csv(release_imps / lr_csv_name, index=False)


# =============================================================================
# SECTION 6: SAVE PROBABILITIES
# =============================================================================
print("\n[STEP 6] Saving probability outputs...")

# The P_te_* matrices are row-aligned to ``test_index``. Upstream filtering may
# drop TEST rows (excluded / unknown labels), so the true labels are looked up
# by that index rather than taken as the whole Luecken_data_Test column, whose
# length (and row order) can differ from the probability matrices.

if EXPORT_DEV:
    probs_raw_df   = pd.DataFrame(P_te_raw,   index=test_index, columns=[f"raw_{c}"   for c in class_names])
    probs_platt_df = pd.DataFrame(P_te_platt, index=test_index, columns=[f"platt_{c}" for c in class_names])
    probs_cal_df   = pd.DataFrame(P_te_cal,   index=test_index, columns=[f"cal_{c}"   for c in class_names])

    probs_dev = pd.concat([probs_raw_df, probs_platt_df, probs_cal_df], axis=1)
    probs_dev["true_label"] = Luecken_data_Test.loc[test_index, "Celltype"].values
    probs_dev["pred_raw"]   = P_te_raw.argmax(axis=1)
    probs_dev["pred_cal"]   = P_te_cal.argmax(axis=1)
    probs_dev["pred_raw_name"] = [class_names[i] for i in probs_dev["pred_raw"].values]
    probs_dev["pred_cal_name"] = [class_names[i] for i in probs_dev["pred_cal"].values]

    probs_dev_path = probs_dir / "probabilities_before_after_TEST.csv"
    probs_dev.to_csv(probs_dev_path, index=True)

if EXPORT_RELEASE:
    probs_cal_df = pd.DataFrame(P_te_cal, index=test_index, columns=[f"cal_{c}" for c in class_names])
    probs_release = probs_cal_df.copy()
    probs_release["true_label"]    = Luecken_data_Test.loc[test_index, "Celltype"].values
    probs_release["pred_cal"]      = P_te_cal.argmax(axis=1)
    probs_release["pred_cal_name"] = [class_names[i] for i in probs_release["pred_cal"].values]
    probs_release["max_cal_prob"]  = probs_cal_df.max(axis=1).values

    release_probs_path = release_probs / "Multiclass_models_probabilities_on_test.csv"
    probs_release.to_csv(release_probs_path, index=True)


# =============================================================================
# SECTION 7: MULTICLASS EVALUATION (TEST) — using CAL probabilities
# =============================================================================
print("\n[STEP 7] Multiclass evaluation (TEST; using CAL probs)...\n")

# Argmax over the temperature-scaled probabilities -> predicted class index.
y_pred_cal = P_te_cal.argmax(axis=1)

# Evaluate over the full class vocabulary 0..K-1. Without an explicit ``labels``
# list, classification_report infers the labels from y_true ∪ y_pred and raises
# when that count differs from len(target_names). That happens, for example,
# when a TRAIN class is absent from TEST and never predicted. Passing the labels
# also matches the confusion matrix below, which always used range(K).
eval_labels = list(range(K))

report_txt = classification_report(
    y_test_multiclass, y_pred_cal, labels=eval_labels, target_names=class_names, digits=3
)
print("Multiclass Classification Report (TEST):")
print(report_txt)

cm_mc = confusion_matrix(y_test_multiclass, y_pred_cal, labels=eval_labels)
print("\nConfusion Matrix (rows=true, cols=pred):")
print(cm_mc)

report = classification_report(
    y_test_multiclass, y_pred_cal, labels=eval_labels, target_names=class_names, output_dict=True
)
report_df = pd.DataFrame(report).T

cm_mc_df = pd.DataFrame(
    cm_mc,
    index=pd.Index(class_names, name="true"),
    columns=pd.Index(class_names, name="pred"),
)

if EXPORT_DEV:
    report_df.to_csv(metrics_dir / "multiclass_classification_report_TEST.csv")
    cm_mc_df.to_csv(metrics_dir / "multiclass_confusion_matrix_TEST.csv")

if EXPORT_RELEASE:
    report_df.to_csv(release_metrics / "Multiclass_models_metrics_on_test.csv")
    cm_mc_df.to_csv(release_metrics / "Multiclass_models_confusion_matrix_on_test.csv")


# =============================================================================
# SECTION 8: FIGURES (MULTICLASS CM + PER-CLASS CONF & ROC)
# =============================================================================
print("\n[STEP 8] Saving plots...")

def _save_multiclass_cm_png(out_path: Path):
    """Render the K x K TEST confusion matrix (global ``cm_mc``) to a PNG.

    Parameters
    ----------
    out_path : Path
        Destination file; written at 300 dpi.
    """
    # Draw into an explicitly sized Axes. ConfusionMatrixDisplay.plot() creates
    # its own default-size figure when no ``ax`` is given. Previously that left
    # the 7x6 figure empty and leaked the figure that was actually plotted.
    fig, ax = plt.subplots(figsize=(7, 6))
    try:
        disp = ConfusionMatrixDisplay(confusion_matrix=cm_mc, display_labels=class_names)
        disp.plot(ax=ax, values_format="d", cmap="Blues", colorbar=False)
        ax.set_title(f"{name_target_class} – Multiclass Confusion Matrix (on TEST)")
        fig.tight_layout()
        fig.savefig(out_path, dpi=300)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

# Multiclass confusion-matrix figure, written to every enabled destination.
if EXPORT_DEV:
    _save_multiclass_cm_png(fig_root / "multiclass_confusion_matrix_TEST.png")
if EXPORT_RELEASE:
    _save_multiclass_cm_png(release_figs / "Multiclass_models_confusion_matrix_on_test.png")

# Per-class (one-vs-rest) evaluation inputs: argmax predictions from the RAW
# and CAL probability matrices, plus the accumulator for per-class metric rows.
y_pred_raw, y_pred_cal = (probs.argmax(axis=1) for probs in (P_te_raw, P_te_cal))
per_class_rows = []

def _metrics_from_cm(cm2x2):
    tn, fp, fn, tp = cm2x2.ravel()
    support = tp + fn
    prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    rec  = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1   = (2 * prec * rec) / (prec + rec) if (prec + rec) > 0 else 0.0
    return dict(TP=int(tp), FP=int(fp), TN=int(tn), FN=int(fn),
                support=int(support), precision=prec, recall=rec, f1=f1)

def _save_cm_fig(cm2x2, cls_label, title, out_dev: Path | None, out_rel: Path | None):
    """Plot a 2x2 one-vs-rest confusion matrix and save it to each given path.

    Parameters
    ----------
    cm2x2 : array-like
        Binary confusion matrix ordered ``[Other, cls_label]`` on both axes.
    cls_label : str
        Display name of the positive class.
    title : str
        Figure title.
    out_dev, out_rel : Path or None
        PNG destinations (DEV / RELEASE); ``None`` skips that destination.
    """
    # Plot into an explicitly sized Axes. Without ``ax``, ConfusionMatrixDisplay
    # opens a new default-size figure, so the requested figsize was ignored and
    # the plotted figure was never closed.
    fig, ax = plt.subplots(figsize=(5.5, 5.0))
    try:
        ConfusionMatrixDisplay(confusion_matrix=cm2x2, display_labels=["Other", cls_label]).plot(
            ax=ax, values_format="d", cmap="Blues", colorbar=False
        )
        ax.set_title(title)
        fig.tight_layout()
        for out_path in (out_dev, out_rel):
            if out_path is not None:
                fig.savefig(out_path, dpi=300)
    finally:
        # Always release the figure; this runs once per class per variant.
        plt.close(fig)

def _save_roc(y_true, y_score, title, out_dev: Path | None, out_rel: Path | None):
    """Plot a one-vs-rest ROC curve, save it, and return the area under it.

    Parameters
    ----------
    y_true : array-like of {0, 1}
        Binary ground truth (1 = class of interest).
    y_score : array-like of float
        Score for the positive class.
    title : str
        Title prefix; ``AUC=<value>`` is appended.
    out_dev, out_rel : Path or None
        PNG destinations (DEV / RELEASE); ``None`` skips that destination.

    Returns
    -------
    float
        ROC AUC computed with sklearn's ``auc`` on the curve points.
    """
    fpr, tpr, _thresholds = roc_curve(y_true, y_score)
    roc_auc = auc(fpr, tpr)

    fig, ax = plt.subplots(figsize=(6.0, 5.5))
    ax.plot([0, 1], [0, 1], linestyle="--", linewidth=1, color="gray")  # chance diagonal
    ax.plot(fpr, tpr)
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title(f"{title} AUC={roc_auc:.3f}")
    fig.tight_layout()

    for destination in (out_dev, out_rel):
        if destination is not None:
            fig.savefig(destination, dpi=300)
    plt.close(fig)
    return roc_auc

# Evaluation variants per class: (file/model tag, probability matrix,
# argmax predictions, qualifier used in figure titles).
per_class_variants = (
    ("RAW", P_te_raw, y_pred_raw, "RAW; pre-Platt & Temp"),
    ("CAL", P_te_cal, y_pred_cal, "CAL; Platt & Temp"),
)

for k, cls in enumerate(class_names):
    cls_safe = MLTraining.safe_name(cls)
    y_true_bin = (y_test_multiclass == k).astype(int)  # one-vs-rest ground truth

    if EXPORT_DEV:
        idx = pd.Index(["True=Other", f"True={cls}"], name="true")
        cols = pd.Index(["Pred=Other", f"Pred={cls}"], name="pred")

    for tag, probs, preds, qualifier in per_class_variants:
        # Binary view of the multiclass argmax: "predicted as this class" vs not.
        cm_bin = confusion_matrix(y_true_bin, (preds == k).astype(int), labels=[0, 1])

        if EXPORT_DEV:
            pd.DataFrame(cm_bin, index=idx, columns=cols).to_csv(
                metrics_dir / f"{cls_safe}_binary_confmat_TEST_ARGMAX_{tag}.csv"
            )

        stem = f"{cls_safe}_{tag}"
        _save_cm_fig(
            cm_bin,
            cls,
            f"{name_target_class} – {cls}: Confusion Matrix ({qualifier})",
            (fig_percls / f"{stem}_confusion_matrix_on_test.png") if EXPORT_DEV else None,
            (release_single / f"{stem}_confusion_matrix_on_test.png") if EXPORT_RELEASE else None,
        )
        auc_value = _save_roc(
            y_true_bin,
            probs[:, k],
            f"{name_target_class} – {cls}: ROC ({qualifier})",
            (fig_percls / f"{stem}_ROC_on_test.png") if EXPORT_DEV else None,
            (release_single / f"{stem}_ROC_on_test.png") if EXPORT_RELEASE else None,
        )

        row = _metrics_from_cm(cm_bin)
        row.update(model=tag, class_name=cls, auc=auc_value)
        per_class_rows.append(row)

if EXPORT_DEV:
    print(f"  ✓ Saved per-class plots (DEV) → {fig_percls}")
if EXPORT_RELEASE:
    print(f"  ✓ Saved per-class plots (RELEASE) → {release_single}")


# =============================================================================
# SECTION 9: SAVE METRICS TABLES
# =============================================================================
print("\n[STEP 9] Saving metrics tables...")

# Fixed column order for the per-class RAW/CAL table, sorted by class then model.
per_class_columns = ["class_name", "model", "TP", "FP", "TN", "FN", "support", "precision", "recall", "f1", "auc"]
per_class_df = (
    pd.DataFrame(per_class_rows)
    .loc[:, per_class_columns]
    .sort_values(["class_name", "model"])
)

if EXPORT_DEV:
    dev_metrics_path = metrics_dir / "per_class_argmax_metrics_TEST_included.csv"
    per_class_df.to_csv(dev_metrics_path, index=False)
    print(f"  ✓ Saved DEV per-class metrics → {dev_metrics_path}")

if EXPORT_RELEASE:
    out_single = release_metrics / "Single_classes_metrics_and_confusion_matrix_on_test.csv"
    per_class_df.to_csv(out_single, index=False)
    print(f"  ✓ Saved RELEASE per-class metrics → {out_single}")

if EXPORT_DEV:
    # Binary-head records gathered during training are appended to a shared CSV.
    stacker_csv = dev_root / "stacker_metrics.csv"
    metrics_df = pd.DataFrame.from_records(metrics_log)
    MLTraining.append_metrics_csv(metrics_df, csv_path=stacker_csv)
    print(f"  ✓ Appended DEV binary-head metrics → {stacker_csv}")

print("\n✅ BROAD PIPELINE COMPLETE. Exports saved according to EXPORT_DEV / EXPORT_RELEASE.\n")


#### Simplified annotation

In [None]:
# -*- coding: utf-8 -*-
# =============================================================================
# MODEL TRAINING PIPELINE (LEAN MAIN SCRIPT)
#   - RAW vs PLATT vs TEMP-SCALED
#   - DEV/RELEASE exports
#   - Importances: XGB SHAP mean_abs + corr (Top10) + LR meta-learner contributions
#   - Platt calibration plots (Ideal -> RAW -> Platt on top) with TEST LogLoss/Brier in legend
#   - Per-class pre/post Platt metrics exported to CSV
#   - Per-class TRAIN UMAP (pos vs rest) + legend PNG
#
# PATCHES ADDED (to address “plots missing / skipped” symptoms):
#   (A) Optional DEBUG_DIAGNOSTICS: prints output paths + CAL class balance + confirms file writes.
#   (B) Hard traceback on failures (instead of silent warnings) to surface root cause.
#   (C) SHAP beeswarm robustification: optional subsample of TRAIN to avoid memory/time failures.
#   (D) Optional SAFE_SINGLE_THREAD: mitigates fork/thread/numba/TBB instability during SHAP/plotting.
#   (E) Explicit existence checks after savefig (so “saved but not where expected” is obvious).
# =============================================================================

# =============================================================================
# SECTION 0: IMPORTS + CONFIG
# =============================================================================

from pathlib import Path
import joblib
import warnings
import traceback
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import StackingClassifier
from sklearn.metrics import (
    classification_report, confusion_matrix, ConfusionMatrixDisplay,
    roc_curve, auc,
)

import MLTraining  # uses MLTraining.py helpers

# -----------------------------------------------------------------------------
# Palettes (label -> hex colour), one per annotation depth
# -----------------------------------------------------------------------------

PALETTE_BROAD = dict(Immature="#0079ea", Mature="#AF3434")

PALETTE_SIMPLIFIED = dict(
    HSPC="#0079ea",
    Erythroid="#c11212",
    pDC="#62E6B8",
    Monocyte="#D27CE3",
    Myeloid="#8D43CD",
    CD4_T="#C99546",
    CD8_T="#6B3317",
    B="#68D827",
    cDC="#16D2E3",
    Other_T="#EDB416",
    NK="#FBEF0D",
)

# Detailed labels contain spaces/dashes, so build from (label, colour) pairs.
# "Monocyte" and "CD14 Mono" intentionally share a colour.
PALETTE_DETAILED = dict([
    ("HSC_MPP",            "#0079ea"),
    ("LMPP",               "#17BECF"),
    ("GMP",                "#C5E4FF"),
    ("Myeloid progenitor", "#AEC7E8"),
    ("Monocyte",           "#D27CE3"),
    ("CD14 Mono",          "#D27CE3"),
    ("CD16 Mono",          "#8D43CD"),
    ("Erythroblast",       "#F30A1A"),
    ("ErP",                "#D1235A"),
    ("MEP",                "#E364B0"),
    ("CD4 T Naive",        "#C99546"),
    ("CD4 T Memory",       "#C1AF93"),
    ("CD8 T Naive",        "#4D382E"),
    ("CD8 T Memory",       "#6B3317"),
    ("Other_T",            "#EDB416"),
    ("Treg",               "#6E6C37"),
    ("B Naive",            "#1C511D"),
    ("B Memory",           "#68D827"),
    ("Pro-B",              "#66BB6A"),
    ("Pre-B",              "#2DBD67"),
    ("Immature B",         "#91FF7B"),
    ("Plasma",             "#9DC012"),
    ("cDC1",               "#76A7CB"),
    ("cDC2",               "#16D2E3"),
    ("pDC",                "#69FFCB"),
    ("NK CD56 bright",     "#F3AC1F"),
    ("NK CD56 dim",        "#FBEF0D"),
])

# Lookup used with name_target_class to pick the active palette.
PALETTE_BY_DEPTH = dict(
    Broad=PALETTE_BROAD,
    Simplified=PALETTE_SIMPLIFIED,
    Detailed=PALETTE_DETAILED,
)

# -----------------------------------------------------------------------------
# OPTIONAL: SHAP dependency
# -----------------------------------------------------------------------------
# Any import-time failure (not only ImportError) is treated as "SHAP unavailable".
HAS_SHAP = False
try:
    import shap  # noqa: F401
except Exception:
    pass
else:
    HAS_SHAP = True

# -----------------------------------------------------------------------------
# EXPORT SWITCHES
# -----------------------------------------------------------------------------
EXPORT_RELEASE = True
EXPORT_DEV     = False

# -----------------------------------------------------------------------------
# CONFIG
# -----------------------------------------------------------------------------
name_target_class = "Simplified"  # "Broad" | "Simplified" | "Detailed"
# Class labels (matched as strings) to leave out of training and evaluation.
# Must be a *set*: a bare ``{}`` literal is an empty dict. Filling that in later
# would either silently change the container type or require dummy values.
EXCLUDE_CLASSES: set[str] = set()

# Palette for the chosen depth; an empty dict when the depth has no entry.
custom_palette = PALETTE_BY_DEPTH.get(name_target_class, {})
kf          = MLTraining.CV  # CV splitter handed to the trainers and the stacker
num_cores   = -1             # forwarded to MLTraining trainers (presumably -1 = all cores)
metrics_log = []

# -----------------------------------------------------------------------------
# DIAGNOSTICS / ROBUSTIFICATION SWITCHES (PATCH)
# -----------------------------------------------------------------------------
DEBUG_DIAGNOSTICS          = True
HARD_TRACEBACKS            = True   # if True: prints stack traces when plot/SHAP fails
SHAP_TRAIN_SUBSAMPLE_MAX_N = 5000   # set None to disable subsampling
SAFE_SINGLE_THREAD         = False  # set True if you see Numba/TBB fork/thread warnings

# -----------------------------------------------------------------------------
# EMBEDDING CONFIG (for Class_Train_data.png)
# -----------------------------------------------------------------------------
EMBEDDING_SOURCE   = "adata_obsm"   # "adata_obsm" | "adata_obs" | "train_df"
EMBEDDING_OBSM_KEY = "X_umap"
# Coordinate column names; presumably only consulted by the matching source.
EMBEDDING_OBS_X, EMBEDDING_OBS_Y = "UMAP_1", "UMAP_2"
EMBEDDING_DF_X, EMBEDDING_DF_Y   = "UMAP_1", "UMAP_2"

# -----------------------------------------------------------------------------
# ROOTS
# -----------------------------------------------------------------------------
Luecken_root = Path(models_output)

# DEV tree: <root>/Dev/<depth>/{Models,Reports,Figures}/<depth>
dev_root = Luecken_root / "Dev"
models_root, reports_root, fig_root = (
    dev_root / name_target_class / section / name_target_class
    for section in ("Models", "Reports", "Figures")
)

heads_dir       = models_root / "heads"
metrics_dir     = reports_root / "metrics"
probs_dir       = reports_root / "probabilities"
dev_importances = reports_root / "Importances"
fig_percls      = fig_root / "per_class"

# RELEASE tree: <root>/Release/<depth>/{Models,Reports/*,Figures/*}
release_root = Luecken_root / "Release"
release_models, release_reports, release_figs = (
    release_root / name_target_class / section
    for section in ("Models", "Reports", "Figures")
)
release_metrics, release_probs, release_imps = (
    release_reports / sub for sub in ("Metrics", "Probabilities", "Importances")
)
release_single = release_figs / "Single_classes"

# Directories are created only for the enabled export targets.
if EXPORT_DEV:
    for directory in (models_root, heads_dir, reports_root, metrics_dir, probs_dir, fig_root, fig_percls, dev_importances):
        directory.mkdir(parents=True, exist_ok=True)

if EXPORT_RELEASE:
    for directory in (release_models, release_reports, release_metrics, release_probs, release_imps, release_figs, release_single):
        directory.mkdir(parents=True, exist_ok=True)
    for label, location in (
        ("Root:    ", release_root),
        ("Models:  ", release_models),
        ("Reports: ", release_reports),
        ("Figures: ", release_figs),
    ):
        print(f"[INFO] RELEASE {label}{location}")

if DEBUG_DIAGNOSTICS:
    print(f"[DEBUG] HAS_SHAP={HAS_SHAP} EXPORT_RELEASE={EXPORT_RELEASE} EXPORT_DEV={EXPORT_DEV}")
    print(f"[DEBUG] release_single={release_single}")
    print(f"[DEBUG] release_imps={release_imps}")
    print(f"[DEBUG] SAFE_SINGLE_THREAD={SAFE_SINGLE_THREAD} SHAP_SUBSAMPLE_MAX_N={SHAP_TRAIN_SUBSAMPLE_MAX_N}")

# =============================================================================
# SECTION 1: ATTACH CELL-TYPE LABELS
# =============================================================================
print("\n[STEP 1] Attaching cell-type labels from AnnData.obs...")

# Label field for the chosen depth, e.g. "Consensus_annotation_simplified_final".
consensus_field = "Consensus_annotation_{}_final".format(name_target_class.lower())

# Each feature table receives the labels from its matching AnnData split.
Luecken_data_Train = MLTraining.attach_celltype(Luecken_data_Train, Luecken_dataset_Train, consensus_field)
Luecken_data_Test = MLTraining.attach_celltype(Luecken_data_Test, Luecken_dataset_Test, consensus_field)
Luecken_data_Cal = MLTraining.attach_celltype(Luecken_data_Cal, Luecken_dataset_Cal, consensus_field)

print(f"  ✓ Attached '{consensus_field}' to Train/Test/Cal splits")

# =============================================================================
# SECTION 2: ALIGN DATA COLUMNS TO REFERENCE PANEL
# =============================================================================
print("\n[STEP 2] Aligning data columns to reference panel (exact names preserved)...")

# Reference panel as strings; normalised keys map back to the exact panel spelling.
panel = pd.Index([str(marker) for marker in TotalSeqD_Heme_Oncology_CAT399906])
panel_keys = MLTraining.norm_feats(panel)
norm_to_panel = dict(zip(panel_keys, panel))

# Fewer unique keys than panel entries means two names normalise to the same key
# (or the panel repeats a name), which would make the mapping ambiguous.
if len(norm_to_panel) != len(panel):
    raise ValueError("Panel contains names that collide after normalization. Adjust MLTraining.norm_feats rules.")

def rename_data_to_panel(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` whose feature columns use exact panel names.

    Feature columns (everything except ``cell_barcode`` / ``Celltype``) are
    normalised with ``MLTraining.norm_feats`` and looked up in the global
    ``norm_to_panel`` map. Columns with no panel match keep their name. When
    several columns map to the same panel name, the first one (in column order)
    is renamed and the others are dropped.

    Parameters
    ----------
    df : pd.DataFrame
        Feature table, optionally with ``cell_barcode`` / ``Celltype`` columns.

    Returns
    -------
    pd.DataFrame
        New frame; the input is not modified.
    """
    out = df.copy()
    meta_cols = [c for c in ("cell_barcode", "Celltype") if c in out.columns]
    feature_cols = pd.Index([c for c in out.columns if c not in meta_cols])

    # Source column -> panel name, for columns that have a panel match.
    candidates = {
        src: target
        for src, key in zip(feature_cols, MLTraining.norm_feats(feature_cols))
        if (target := norm_to_panel.get(key)) is not None
    }

    claimed, safe_map, drops = set(), {}, []
    for src, target in candidates.items():
        if target in claimed:
            drops.append(src)  # a panel name can be claimed only once
            continue
        claimed.add(target)
        safe_map[src] = target

    if drops:
        print(f"  [WARN] Dropping {len(drops)} duplicated-mapped columns (sample: {drops[:5]})")
        out = out.drop(columns=drops, errors="ignore")

    out = out.rename(columns=safe_map)
    print(f"  ✓ Matched {len(safe_map)}/{len(feature_cols)} data columns to panel")
    return out

def panel_intersection(df: pd.DataFrame) -> pd.DataFrame:
    """Keep only features present in the global ``panel``, in panel order.

    The ``cell_barcode`` / ``Celltype`` columns, when present, are appended after
    the features.

    Raises
    ------
    ValueError
        If no feature column of ``df`` appears in the panel.
    """
    meta_cols = [c for c in ("cell_barcode", "Celltype") if c in df.columns]
    shared = panel.intersection(pd.Index([c for c in df.columns if c not in meta_cols]), sort=False)
    if shared.empty:
        raise ValueError("Panel/Data intersection is empty after renaming. Check mapping rules.")
    return df.reindex(columns=[*shared, *meta_cols])

# Rename each split to exact panel spelling, then restrict it to panel ∩ data.
Luecken_data_Train = rename_data_to_panel(Luecken_data_Train)
Luecken_data_Train = panel_intersection(Luecken_data_Train)
Luecken_data_Test = rename_data_to_panel(Luecken_data_Test)
Luecken_data_Test = panel_intersection(Luecken_data_Test)
Luecken_data_Cal = rename_data_to_panel(Luecken_data_Cal)
Luecken_data_Cal = panel_intersection(Luecken_data_Cal)
print("  ✓ Data columns now aligned to panel (panel order preserved)")

# =============================================================================
# SECTION 3: PREPARE FEATURES & LABELS (WITH CAL/TEST ROW FILTERING)
# =============================================================================
print("\n[STEP 3] Extracting features and labels...")

# Keep the CAL labels aside before the label column is stripped from features.
Luecken_data_Cal_lbl = Luecken_data_Cal[["Celltype"]].copy()

# Non-feature columns present in each split (train, test, cal order).
drop_cols_train, drop_cols_test, drop_cols_cal = (
    [c for c in ["cell_barcode", "Celltype"] if c in split.columns]
    for split in (Luecken_data_Train, Luecken_data_Test, Luecken_data_Cal)
)

Luecken_data_Train_Sub = Luecken_data_Train.drop(columns=drop_cols_train, errors="ignore")
Luecken_data_Test_Sub = Luecken_data_Test.drop(columns=drop_cols_test, errors="ignore")
Luecken_data_Cal_Sub = Luecken_data_Cal.drop(columns=drop_cols_cal, errors="ignore")

# All splits must expose identical feature columns in identical order.
cols_train = list(Luecken_data_Train_Sub.columns)
if any(list(sub.columns) != cols_train for sub in (Luecken_data_Test_Sub, Luecken_data_Cal_Sub)):
    raise ValueError("Train/Cal/Test feature columns differ after panel intersection!")

MLTraining.check_finite(Luecken_data_Train_Sub, "TRAIN")
MLTraining.check_finite(Luecken_data_Test_Sub, "TEST")
MLTraining.check_finite(Luecken_data_Cal_Sub, "CAL")

print(f"  ✓ Using {len(cols_train)} panel-intersected features (exact panel names)")
print(f"    Sample: {cols_train[:5]}...")

# Class vocabulary: sorted non-null TRAIN labels, minus user exclusions
# (exclusions are matched on the string form of each label).
all_classes = sorted(pd.Series(Luecken_data_Train["Celltype"]).dropna().unique())
class_names = [label for label in all_classes if str(label) not in EXCLUDE_CLASSES]

excluded_present = sorted(set(all_classes) & set(EXCLUDE_CLASSES))
if excluded_present:
    print(f"  [INFO] Excluding {len(excluded_present)} classes: {excluded_present}")

# Column index of each class in every probability matrix below.
K = len(class_names)
class_to_idx = {label: pos for pos, label in enumerate(class_names)}
print(f"  ✓ Found {K} classes after exclusions")

# ---- critical: filter CAL/TEST rows to those classes ----
# Labels are compared as strings, matching how class_names was filtered.
keep_set = {str(c) for c in class_names}

cal_keep_mask  = Luecken_data_Cal_lbl["Celltype"].astype(str).isin(keep_set)
test_keep_mask = Luecken_data_Test["Celltype"].astype(str).isin(keep_set)

n_cal_drop  = int((~cal_keep_mask).sum())
n_test_drop = int((~test_keep_mask).sum())

# Report which labels each split loses (only when something is dropped).
for split_tag, split_frame, split_mask, split_drop in (
    ("CAL", Luecken_data_Cal_lbl, cal_keep_mask, n_cal_drop),
    ("TEST", Luecken_data_Test, test_keep_mask, n_test_drop),
):
    if split_drop > 0:
        dropped = sorted(split_frame.loc[~split_mask, "Celltype"].astype(str).unique().tolist())
        print(f"  [INFO] Dropping {split_drop} {split_tag} rows with excluded/unknown labels: {dropped}")

# filtered label frames
Luecken_data_Cal_lbl_f  = Luecken_data_Cal_lbl.loc[cal_keep_mask].copy()
Luecken_data_Test_lbl_f = Luecken_data_Test.loc[test_keep_mask, ["Celltype"]].copy()

# filtered feature frames, row-aligned to the label frames by index
X_cal_all_df = Luecken_data_Cal_Sub.loc[Luecken_data_Cal_lbl_f.index].copy()
X_te_all_df  = Luecken_data_Test_Sub.loc[Luecken_data_Test_lbl_f.index].copy()
test_index   = X_te_all_df.index


def _require_mapped(codes: pd.Series, lbl_df: pd.DataFrame, split_tag: str) -> None:
    """Raise if any filtered label of *lbl_df* failed to map to a class index."""
    unmapped = codes.isna()
    if unmapped.any():
        missing = lbl_df.loc[unmapped, "Celltype"].unique()
        raise ValueError(f"Still-unmapped labels in {split_tag} after filtering: {missing}")


# map filtered labels to integer class indices
s_cal = Luecken_data_Cal_lbl_f["Celltype"].map(class_to_idx)
_require_mapped(s_cal, Luecken_data_Cal_lbl_f, "CAL")
y_cal_multiclass = s_cal.to_numpy(dtype=np.int64)

s_te = Luecken_data_Test_lbl_f["Celltype"].map(class_to_idx)
_require_mapped(s_te, Luecken_data_Test_lbl_f, "TEST")
y_test_multiclass = s_te.to_numpy(dtype=np.int64)

# Per-class probability matrices (rows = filtered CAL/TEST cells, cols = classes),
# zero-initialised; each is a separate array.
P_cal_raw, P_cal_platt = (np.zeros((len(X_cal_all_df), K), dtype=float) for _ in range(2))
P_te_raw, P_te_platt = (np.zeros((len(X_te_all_df), K), dtype=float) for _ in range(2))

# In-memory heads for the deployable package.
heads_mem = {}

# Row accumulators for the SHAP / LR-contribution / Platt-metric exports.
xgb_shap_rows, lr_contrib_rows, platt_metrics_rows = [], [], []

# =============================================================================
# SECTION 4: TRAIN OvR BINARY HEADS (+ Platt on CAL)
# =============================================================================
print(f"\n[STEP 4] Training {K} binary OvR classifiers...\n")

TOP_N = 10  # features kept per class in the importance exports
# Base-learner names, in the same order as the stacker's estimator labels.
base_order = "NB XGB KNN MLP".split()

for celltype in class_names:
    k = class_to_idx[celltype]
    cls_safe = MLTraining.safe_name(celltype)
    print(f"▸ Processing {cls_safe} (class {k+1}/{K})")

    # 4.1 Load TRAIN barcodes for this class
    train_barcodes_df = pd.read_csv(
        f"{train_barcodes_path}/Luecken/Consensus_annotation_{name_target_class.lower()}_final/Barcodes_training_class_{cls_safe}.csv",
        index_col=0
    )
    train_positive_barcodes = train_barcodes_df["Positive"].dropna().values
    train_negative_barcodes = train_barcodes_df["Negative"].dropna().values
    all_train_barcodes = np.concatenate([train_positive_barcodes, train_negative_barcodes])

    train_mask = Luecken_data_Train_Sub.index.isin(all_train_barcodes)
    X_tr_df = Luecken_data_Train_Sub.loc[train_mask]
    found_train_barcodes = X_tr_df.index.values
    y_tr = np.isin(found_train_barcodes, train_positive_barcodes).astype(int)

    if X_tr_df.empty or np.unique(y_tr).size < 2:
        print(f"  [SKIP] Empty or single-class train (pos={y_tr.sum()}, neg={len(y_tr)-y_tr.sum()})\n")
        continue

    # 4.1b TRAIN embedding (pos vs rest) + legend
    try:
        MLTraining.save_class_train_umap_pngs(
            celltype=str(celltype),
            cls_safe=cls_safe,
            barcodes=found_train_barcodes,
            y_bin=y_tr,
            custom_palette=custom_palette,
            out_dir_dev=fig_percls if EXPORT_DEV else None,
            out_dir_rel=release_single if EXPORT_RELEASE else None,
            adata_train=Luecken_dataset_Train,
            train_df=Luecken_data_Train,
            embedding_source=EMBEDDING_SOURCE,
            obsm_key=EMBEDDING_OBSM_KEY,
            obs_x=EMBEDDING_OBS_X,
            obs_y=EMBEDDING_OBS_Y,
            df_x=EMBEDDING_DF_X,
            df_y=EMBEDDING_DF_Y,
            neg_color="#A3A3A3",
            outline=(5, 0.05),
            debug=False,
        )
    except Exception as e:
        warnings.warn(f"UMAP train plot failed for '{celltype}': {e}")

    # 4.2 Load TEST barcodes for class-specific metrics (optional)
    test_barcodes_df = pd.read_csv(
        f"{test_barcodes_path}/Luecken/Consensus_annotation_{name_target_class.lower()}_final/Barcodes_testing_class_{cls_safe}.csv",
        index_col=0
    )
    test_positive_barcodes = test_barcodes_df["Positive"].dropna().values
    test_negative_barcodes = test_barcodes_df["Negative"].dropna().values
    all_test_barcodes = np.concatenate([test_positive_barcodes, test_negative_barcodes])

    test_mask = Luecken_data_Test_Sub.index.isin(all_test_barcodes)
    X_te_df = Luecken_data_Test_Sub.loc[test_mask]
    found_test_barcodes = X_te_df.index.values
    y_te = np.isin(found_test_barcodes, test_positive_barcodes).astype(int)

    # Full TEST (filtered) for head probabilities / calibration plot eval
    X_te_all_local = X_te_all_df
    y_te_all = (Luecken_data_Test.loc[X_te_all_df.index, "Celltype"].astype(str).values == str(celltype)).astype(int)

    # CAL split (filtered) for Platt fitting
    X_cal_df  = X_cal_all_df
    y_cal_bin = (Luecken_data_Cal.loc[X_cal_all_df.index, "Celltype"].astype(str).values == str(celltype)).astype(int)

    # 4.3 Fit scaler on TRAIN; transform all splits
    scaler = StandardScaler(with_mean=True, with_std=True).fit(X_tr_df.values)

    def _sc(df: pd.DataFrame) -> pd.DataFrame:
        return pd.DataFrame(scaler.transform(df.values), index=df.index, columns=cols_train)

    X_tr_sc_df      = _sc(X_tr_df)
    X_te_sc_df      = _sc(X_te_df)
    X_te_all_sc_df  = _sc(X_te_all_local)
    X_cal_sc_df     = _sc(X_cal_df)

    # 4.4 Train base learners
    NB_model  = MLTraining.train_NB (X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)
    XGB_model = MLTraining.train_XGB(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)
    KNN_model = MLTraining.train_KNN(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)
    MLP_model = MLTraining.train_MLP(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)

    # 4.5 Stacking RAW head
    stacker_raw = StackingClassifier(
        estimators=[("NB", NB_model), ("XGB", XGB_model), ("KNN", KNN_model), ("MLP", MLP_model)],
        final_estimator=LogisticRegression(max_iter=2000, class_weight="balanced", random_state=42),
        stack_method="predict_proba",
        cv=kf,
        n_jobs=-1,
    ).fit(X_tr_sc_df, y_tr)

    # 4.6 Platt calibration (fit on CAL only)
    pos_cal   = int(y_cal_bin.sum())
    n_cal_bin = int(len(y_cal_bin))
    has_both  = (0 < pos_cal < n_cal_bin)

    stacker_platt = None
    if has_both:
        stacker_platt = MLTraining.calibrate_prefit(stacker_raw, X_cal_sc_df, y_cal_bin, method="sigmoid")
    else:
        print("    [WARN] Skipped Platt calibration (single-class CAL)")

    # 4.7 Platt evaluation curve on TEST (Ideal -> RAW -> Platt) + metrics row
    try:
        p_test_raw   = stacker_raw.predict_proba(X_te_all_sc_df)[:, 1]
        p_test_platt = stacker_platt.predict_proba(X_te_all_sc_df)[:, 1] if stacker_platt is not None else None

        dev_platt = (fig_percls / f"{cls_safe}_Platt_calibration_evaluation_on_test.png") if EXPORT_DEV else None
        rel_platt = (release_single / f"{cls_safe}_Platt_calibration_evaluation_on_test.png") if EXPORT_RELEASE else None

        ll_raw, br_raw, ll_pl, br_pl, pl_avail = MLTraining.plot_platt_calibration_on_test(
            y_true_bin=y_te_all.astype(int),
            p_raw=p_test_raw,
            p_platt=p_test_platt,
            title=f"{name_target_class} – {celltype}: Platt calibration evaluation on TEST",
            out_png_dev=dev_platt,
            out_png_rel=rel_platt,
            n_bins=15,
        )

        platt_metrics_rows.append({
            "depth": name_target_class,
            "class_name": str(celltype),
            "n_test_samples": int(len(y_te_all)),
            "n_test_positive": int(y_te_all.sum()),
            "logloss_raw": ll_raw,
            "brier_raw": br_raw,
            "logloss_platt": ll_pl,
            "brier_platt": br_pl,
            "platt_available": bool(pl_avail),
        })

    except Exception as e:
        warnings.warn(f"Platt calibration plot failed for class '{celltype}': {e}")

    # 4.8 Save per-class head bundle + keep in-memory for package
    head_bundle = {
        "atlas": "Luecken",
        "depth": name_target_class,
        "label": str(celltype),
        "panel_name": "TotalSeqD_Heme_Oncology_CAT399906",
        "columns": cols_train,
        "scaler": scaler,
        "model_raw": stacker_raw,
        "model_platt": stacker_platt,
    }
    heads_mem[str(celltype)] = head_bundle

    if EXPORT_DEV:
        joblib.dump(head_bundle, heads_dir / f"{cls_safe}.joblib")

    # 4.9 Optional per-head metrics logging (class-specific TEST subset)
    try:
        model_for_eval = stacker_platt if stacker_platt is not None else stacker_raw
        m = MLTraining.evaluate_classifier(model_for_eval, X_te_sc_df, y_te, plot_cm=False)
        m.update(celltype=str(celltype), used_platt=bool(stacker_platt is not None))
        metrics_log.append(m)
    except Exception:
        pass

    # 4.10 OvR probability matrices (RAW + PLATT) for multiclass downstream
    P_cal_raw[:, k] = stacker_raw.predict_proba(X_cal_sc_df)[:, 1]
    P_te_raw[:,  k] = stacker_raw.predict_proba(X_te_all_sc_df)[:, 1]

    if stacker_platt is not None:
        P_cal_platt[:, k] = stacker_platt.predict_proba(X_cal_sc_df)[:, 1]
        P_te_platt[:,  k] = stacker_platt.predict_proba(X_te_all_sc_df)[:, 1]
    else:
        P_cal_platt[:, k] = P_cal_raw[:, k]
        P_te_platt[:,  k] = P_te_raw[:,  k]

    # 4.11 SHAP: mean_abs + corr on TEST; beeswarm TRAIN only
    if HAS_SHAP:
        try:
            plt.figure(figsize=(6, 6))
            shap_sum_test = MLTraining.xgb_shap_mean_abs_and_corr(XGB_model, X_te_all_sc_df, class_index=1)
            shap_sum_test["depth"] = name_target_class
            shap_sum_test["class_name"] = str(celltype)
            shap_sum_test["dataset"] = "TEST"
            xgb_shap_rows.extend(shap_sum_test.to_dict(orient="records"))

            # Beeswarm on TRAIN only
            if EXPORT_DEV:
                outp = fig_percls / f"{cls_safe}_SHAP_beeswarm_TRAIN.png"
                sv, _ = MLTraining.xgb_shap_values(XGB_model, X_tr_sc_df, class_index=1)
                MLTraining.plot_xgb_shap_beeswarm(
                    sv, X_tr_sc_df,
                    title=f"{name_target_class} – {celltype}: SHAP importance",
                    max_display=TOP_N,
                    out_path=outp,
                    figsize=(6, 7),
                )
            if EXPORT_RELEASE:
                outp = release_single / f"{cls_safe}_SHAP_beeswarm_TRAIN.png"
                sv, _ = MLTraining.xgb_shap_values(XGB_model, X_tr_sc_df, class_index=1)
                MLTraining.plot_xgb_shap_beeswarm(
                    sv, X_tr_sc_df,
                    title=f"{name_target_class} – {celltype}: SHAP importance",
                    max_display=TOP_N,
                    out_path=outp,
                    figsize=(6, 7),
                )

        except Exception as e:
            warnings.warn(f"SHAP failed for class '{celltype}': {e}")

    # 4.12 LR meta-learner contributions (unchanged)
    try:
        contrib = _lr_baselearner_contributions(stacker_raw, X_te_all_sc_df, base_order=base_order)
        row = {
            "depth": name_target_class,
            "class_name": str(celltype),
            "dataset": "TEST",
            "n_meta_features": contrib["n_meta_features"],
            "per_estimator_meta_cols": contrib["per_estimator_meta_cols"],
        }
        for b in base_order:
            row[f"{b}_mean_abs_contribution"] = contrib["per_base"].get(b, {}).get("mean_abs_contribution", 0.0)
            row[f"{b}_coef_l1"]               = contrib["per_base"].get(b, {}).get("coef_l1", 0.0)
            row[f"{b}_n_meta_cols"]           = contrib["per_base"].get(b, {}).get("n_cols", 0)
        lr_contrib_rows.append(row)
    except Exception as e:
        warnings.warn(f"LR contribution extraction failed for class '{celltype}': {e}")

    print("")

# =============================================================================
# EXPORT: Per-class LogLoss & Brier (pre vs post Platt) on TEST
# =============================================================================
print("\n[EXPORT] Per-class calibration metrics (RAW vs Platt on TEST)...")

# One CSV built from the per-class rows gathered in the training loop; a
# destination directory is handed over only for export modes that are enabled.
_ = MLTraining.export_platt_metrics_csv(
    platt_metrics_rows,
    filename="Single_classes_metrics_pre_and_post_platt_calibration.csv",
    out_rel=(release_metrics if EXPORT_RELEASE else None),
    out_dev=(metrics_dir if EXPORT_DEV else None),
)

# =============================================================================
# SECTION 5: MULTICLASS TEMPERATURE SCALING (fit on CAL using PLATT matrix)
# =============================================================================
print("\n[STEP 5] Multiclass Temperature Scaling on CAL (using Platt OvR probabilities)...")

def _check_probs(P: np.ndarray, name: str):
    if np.isnan(P).any() or np.isinf(P).any():
        raise ValueError(f"{name} contains NaN/Inf")
    if (P < 0).any() or (P > 1).any():
        raise ValueError(f"{name} contains values outside [0,1]")

_check_probs(P_cal_platt, "P_cal_platt")
_check_probs(P_te_platt,  "P_te_platt")

# Fit the multiclass temperature on the CAL split's Platt OvR matrix, then
# apply it to the TEST split's Platt OvR matrix.
ts_cal = TemperatureScaling()
ts_cal.fit(P_cal_platt, y_cal_multiclass)
P_te_cal = np.asarray(ts_cal.transform(P_te_platt))

# Record what TemperatureScaling actually returned *before* any reshaping or
# fallback, so the warning below reports the real shape (previously the message
# printed the shape of the fallback matrix, i.e. always (n, K)).
ts_returned_shape = P_te_cal.shape

if P_te_cal.ndim == 1:
    P_te_cal = P_te_cal.reshape(-1, 1)

if P_te_cal.shape[1] == 1 and K == 2:
    # Single positive-class column for a binary problem: rebuild [P(0), P(1)].
    P_te_cal = np.hstack([1.0 - P_te_cal, P_te_cal])
elif P_te_cal.shape[1] != K:
    # Unexpected width: fall back to row-normalizing the Platt OvR probabilities
    # (all-zero rows are left as zeros rather than divided by zero).
    row_sums = P_te_platt.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0.0] = 1.0
    P_te_cal = P_te_platt / row_sums
    print(f"  [WARN] TemperatureScaling returned shape {ts_returned_shape}; fell back to sum-normalized OvR probs")

if EXPORT_DEV:
    joblib.dump(ts_cal, models_root / "temp_scaler.joblib")
    pd.Series(class_names, name="class_name").to_csv(models_root / "class_names.csv", index=False)

# =============================================================================
# SECTION 5b: SAVE DEPLOYABLE PACKAGE(S)
# =============================================================================
print("\n[STEP 5b] Saving deployable package(s)...")

# Single object bundling everything needed at inference time for this depth:
# the per-class head bundles (scaler + RAW/Platt stackers, keyed by label) and
# the multiclass temperature scaler fitted on CAL.
package = dict(
    atlas="Luecken",
    depth=name_target_class,
    panel_name="TotalSeqD_Heme_Oncology_CAT399906",
    class_names=class_names,
    heads=heads_mem,
    temp_scaler=ts_cal,
)

if EXPORT_DEV:
    joblib.dump(package, models_root / "package.joblib")
if EXPORT_RELEASE:
    joblib.dump(package, release_models / "Multiclass_models.joblib")

# =============================================================================
# SECTION 5c: EXPORT IMPORTANCES (Top10 per class)
# =============================================================================
print("\n[STEP 5c] Exporting importances (Top 10 per class; SHAP mean_abs + corr + LR)...")

# --- XGB SHAP: keep the TOP_N features per (depth, class_name) by mean |SHAP| ---
if not xgb_shap_rows:
    print("  [INFO] No SHAP rows collected (or SHAP not installed).")
else:
    shap_df = (
        pd.DataFrame(xgb_shap_rows)
        .sort_values(["depth", "class_name", "mean_abs_shap"], ascending=[True, True, False])
        .groupby(["depth", "class_name"], as_index=False)
        .head(TOP_N)
    )
    # 1-based rank inside each class; method="first" breaks ties by sorted order.
    shap_df["rank_within_class"] = (
        shap_df.groupby(["depth", "class_name"])["mean_abs_shap"]
        .rank(ascending=False, method="first")
        .astype(int)
    )
    shap_df = shap_df[
        [
            "depth", "class_name", "dataset", "feature",
            "mean_abs_shap", "corr_feature_value_vs_shap", "rank_within_class",
        ]
    ]
    if EXPORT_DEV:
        shap_df.to_csv(dev_importances / "SHAP_XGB_Feature_importances.csv", index=False)
    if EXPORT_RELEASE:
        shap_df.to_csv(release_imps / "SHAP_XGB_Feature_importances.csv", index=False)

# --- Stacking meta-learner: per-base-learner contribution rows, written as-is ---
if not lr_contrib_rows:
    print("  [INFO] No LR contribution rows collected.")
else:
    lr_df = pd.DataFrame(lr_contrib_rows)
    if EXPORT_DEV:
        lr_df.to_csv(dev_importances / "LR_MetaLearner_BaseLearner_contributions.csv", index=False)
    if EXPORT_RELEASE:
        lr_df.to_csv(release_imps / "LR_MetaLearner_BaseLearner_contributions.csv", index=False)

# =============================================================================
# SECTION 6: SAVE PROBABILITIES
# =============================================================================
print("\n[STEP 6] Saving probability outputs...")

if EXPORT_DEV:
    # One prefixed column block per stage: raw stacker, per-class Platt, temperature-scaled.
    probs_raw_df, probs_platt_df, probs_cal_df = (
        pd.DataFrame(P_stage, index=test_index, columns=[f"{prefix}_{c}" for c in class_names])
        for prefix, P_stage in (("raw", P_te_raw), ("platt", P_te_platt), ("cal", P_te_cal))
    )

    probs_dev = pd.concat([probs_raw_df, probs_platt_df, probs_cal_df], axis=1)
    probs_dev["true_label"] = Luecken_data_Test.loc[test_index, "Celltype"].values
    # Argmax class indices first, then their names (same column order as before).
    probs_dev["pred_raw"] = P_te_raw.argmax(axis=1)
    probs_dev["pred_cal"] = P_te_cal.argmax(axis=1)
    for stage in ("raw", "cal"):
        probs_dev[f"pred_{stage}_name"] = [class_names[i] for i in probs_dev[f"pred_{stage}"].values]
    probs_dev.to_csv(probs_dir / "probabilities_before_after_TEST.csv", index=True)

if EXPORT_RELEASE:
    # Release copy carries only the final (temperature-scaled) probabilities.
    probs_cal_df = pd.DataFrame(P_te_cal, index=test_index, columns=[f"cal_{c}" for c in class_names])
    probs_release = probs_cal_df.copy()
    probs_release["true_label"] = Luecken_data_Test.loc[test_index, "Celltype"].values
    probs_release["pred_cal"] = P_te_cal.argmax(axis=1)
    probs_release["pred_cal_name"] = [class_names[i] for i in probs_release["pred_cal"].values]
    # Confidence of the winning class, taken from the probability columns only.
    probs_release["max_cal_prob"] = probs_cal_df.max(axis=1).values
    probs_release.to_csv(release_probs / "Multiclass_models_probabilities_on_test.csv", index=True)

# =============================================================================
# SECTION 7: MULTICLASS EVALUATION (TEST) — using CAL probabilities
# =============================================================================
print("\n[STEP 7] Multiclass evaluation (TEST; using CAL probs)...\n")

# Predicted class index = argmax over the temperature-scaled (CAL) probabilities.
y_pred_cal = P_te_cal.argmax(axis=1)

# Pin the label set to 0..K-1 so report rows always line up with `class_names`.
# Without `labels`, classification_report infers the labels from y_true/y_pred
# and raises a ValueError (size mismatch with target_names) as soon as some class
# occurs in neither. The confusion matrix below uses the same label set.
eval_labels = list(range(K))

report_txt = classification_report(
    y_test_multiclass, y_pred_cal, labels=eval_labels, target_names=class_names, digits=3
)
print("Multiclass Classification Report (TEST):")
print(report_txt)

cm_mc = confusion_matrix(y_test_multiclass, y_pred_cal, labels=eval_labels)
print("\nConfusion Matrix (rows=true, cols=pred):")
print(cm_mc)

report_df = pd.DataFrame(
    classification_report(
        y_test_multiclass, y_pred_cal, labels=eval_labels, target_names=class_names, output_dict=True
    )
).T

cm_mc_df = pd.DataFrame(cm_mc, index=pd.Index(class_names, name="true"), columns=pd.Index(class_names, name="pred"))

if EXPORT_DEV:
    report_df.to_csv(metrics_dir / "multiclass_classification_report_TEST.csv")
    cm_mc_df.to_csv(metrics_dir / "multiclass_confusion_matrix_TEST.csv")

if EXPORT_RELEASE:
    report_df.to_csv(release_metrics / "Multiclass_models_metrics_on_test.csv")
    cm_mc_df.to_csv(release_metrics / "Multiclass_models_confusion_matrix_on_test.csv")

# =============================================================================
# SECTION 8: FIGURES (MULTICLASS CM + PER-CLASS CONF & ROC)
# =============================================================================
print("\n[STEP 8] Saving plots...")

def _save_multiclass_cm_png(out_path: Path):
    """Render the K x K TEST confusion matrix ``cm_mc`` and save it to *out_path*.

    Parameters
    ----------
    out_path : Path
        Destination PNG (written at 300 dpi).
    """
    # Plot onto an explicit axes: ConfusionMatrixDisplay.plot() otherwise opens a
    # new figure of its own, which left the (7, 6) figure empty, ignored the size,
    # and leaked the figure that actually held the plot.
    fig, ax = plt.subplots(figsize=(7, 6))
    try:
        disp = ConfusionMatrixDisplay(confusion_matrix=cm_mc, display_labels=class_names)
        disp.plot(ax=ax, values_format="d", cmap="Blues", colorbar=False)
        ax.set_title(f"{name_target_class} – Multiclass Confusion Matrix (on TEST)")
        fig.tight_layout()
        fig.savefig(out_path, dpi=300)
    finally:
        plt.close(fig)

if EXPORT_DEV:
    _save_multiclass_cm_png(fig_root / "multiclass_confusion_matrix_TEST.png")
if EXPORT_RELEASE:
    _save_multiclass_cm_png(release_figs / "Multiclass_models_confusion_matrix_on_test.png")

# Accumulates one metrics row per (class, RAW|CAL) in the per-class loop below.
per_class_rows = []
# Argmax over the un-calibrated stacker probabilities, for the RAW comparison.
y_pred_raw = P_te_raw.argmax(axis=1)

def _metrics_from_cm(cm2x2):
    tn, fp, fn, tp = cm2x2.ravel()
    support = tp + fn
    prec = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    rec  = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1   = (2 * prec * rec) / (prec + rec) if (prec + rec) > 0 else 0.0
    return dict(TP=int(tp), FP=int(fp), TN=int(tn), FN=int(fn),
                support=int(support), precision=prec, recall=rec, f1=f1)

def _save_cm_fig(cm2x2, cls_label, title, out_dev: Path | None, out_rel: Path | None):
    """Render a 2x2 one-vs-rest confusion matrix and save it to each non-None path.

    Parameters
    ----------
    cm2x2 : np.ndarray
        Counts for labels ``[0, 1]``, i.e. ``[[TN, FP], [FN, TP]]``.
    cls_label : str
        Display name of the positive class; the negative class is shown as "Other".
    title : str
        Axes title.
    out_dev, out_rel : Path | None
        Destination PNGs (300 dpi); ``None`` skips that destination.
    """
    # Plot onto an explicit axes: ConfusionMatrixDisplay.plot() otherwise opens its
    # own figure, leaving the (5.5, 5.0) figure empty, ignoring the size and never
    # closing the figure that holds the plot (leaked twice per class).
    fig, ax = plt.subplots(figsize=(5.5, 5.0))
    try:
        ConfusionMatrixDisplay(confusion_matrix=cm2x2, display_labels=["Other", cls_label]).plot(
            ax=ax, values_format="d", cmap="Blues", colorbar=False
        )
        ax.set_title(title)
        fig.tight_layout()
        for out_path in (out_dev, out_rel):
            if out_path is not None:
                fig.savefig(out_path, dpi=300)
    finally:
        plt.close(fig)

def _save_roc(y_true, y_score, title, out_dev: Path | None, out_rel: Path | None):
    """Plot a one-vs-rest ROC curve, save it to each non-None path, and return the AUC.

    Parameters
    ----------
    y_true : array-like
        Binary ground truth (1 = positive class).
    y_score : array-like
        Score for the positive class.
    title : str
        Title prefix; " AUC=<value>" is appended.
    out_dev, out_rel : Path | None
        Destination PNGs (300 dpi); ``None`` skips that destination.

    Returns
    -------
    float
        Area under the ROC curve.
    """
    fpr, tpr, _thresholds = roc_curve(y_true, y_score)
    roc_auc = auc(fpr, tpr)

    fig, ax = plt.subplots(figsize=(6.0, 5.5))
    ax.plot([0, 1], [0, 1], linestyle="--", linewidth=1, color="gray")  # chance diagonal
    ax.plot(fpr, tpr)
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title(f"{title} AUC={roc_auc:.3f}")
    fig.tight_layout()
    for destination in (out_dev, out_rel):
        if destination is not None:
            fig.savefig(destination, dpi=300)
    plt.close(fig)
    return roc_auc

# Per-class one-vs-rest views of the multiclass argmax predictions, produced once
# for the RAW stacker scores and once for the temperature-scaled (CAL) scores.
for k, cls in enumerate(class_names):
    cls_safe = MLTraining.safe_name(cls)
    y_true_bin = (y_test_multiclass == k).astype(int)

    for tag, caption, y_pred_all, P_all in (
        ("RAW", "RAW; pre-Platt & Temp", y_pred_raw, P_te_raw),
        ("CAL", "CAL; Platt & Temp", y_pred_cal, P_te_cal),
    ):
        # Binarise the argmax prediction against class k; labels=[0, 1] forces a 2x2 matrix.
        cm_bin = confusion_matrix(y_true_bin, (y_pred_all == k).astype(int), labels=[0, 1])

        cm_png = f"{cls_safe}_{tag}_confusion_matrix_on_test.png"
        _save_cm_fig(
            cm_bin,
            cls,
            f"{name_target_class} – {cls}: Confusion Matrix ({caption})",
            fig_percls / cm_png if EXPORT_DEV else None,
            release_single / cm_png if EXPORT_RELEASE else None,
        )

        roc_png = f"{cls_safe}_{tag}_ROC_on_test.png"
        auc_value = _save_roc(
            y_true_bin,
            P_all[:, k],
            f"{name_target_class} – {cls}: ROC ({caption})",
            fig_percls / roc_png if EXPORT_DEV else None,
            release_single / roc_png if EXPORT_RELEASE else None,
        )

        metrics_row = _metrics_from_cm(cm_bin)
        metrics_row.update(model=tag, class_name=cls, auc=auc_value)
        per_class_rows.append(metrics_row)

# =============================================================================
# SECTION 9: SAVE METRICS TABLES
# =============================================================================
print("\n[STEP 9] Saving metrics tables...")

# One row per (class, RAW|CAL): argmax-derived OvR counts, precision/recall/F1 and AUC.
per_class_df = (
    pd.DataFrame(per_class_rows)
    .loc[:, ["class_name", "model", "TP", "FP", "TN", "FN", "support", "precision", "recall", "f1", "auc"]]
    .sort_values(["class_name", "model"])
)

if EXPORT_DEV:
    per_class_df.to_csv(metrics_dir / "per_class_argmax_metrics_TEST_included.csv", index=False)

if EXPORT_RELEASE:
    out_single = release_metrics / "Single_classes_metrics_and_confusion_matrix_on_test.csv"
    per_class_df.to_csv(out_single, index=False)

# Per-head records collected in `metrics_log` are appended to the DEV-level CSV.
if EXPORT_DEV:
    metrics_df = pd.DataFrame.from_records(metrics_log)
    MLTraining.append_metrics_csv(metrics_df, csv_path=dev_root / "stacker_metrics.csv")

print("\n✅ SIMPLIFIED PIPELINE COMPLETE. Exports saved according to EXPORT_DEV / EXPORT_RELEASE.\n")


#### Detailed annotation

In [None]:
# -*- coding: utf-8 -*-
# =============================================================================
# MODEL TRAINING PIPELINE (LEAN MAIN SCRIPT)
#   - RAW vs PLATT vs TEMP-SCALED
#   - DEV/RELEASE exports
#   - Importances: XGB SHAP mean_abs + corr (Top10) + LR meta-learner contributions
#   - Platt calibration plots (Ideal -> RAW -> Platt on top) with TEST LogLoss/Brier in legend
#   - Per-class pre/post Platt metrics exported to CSV
#   - Per-class TRAIN UMAP (pos vs rest) + legend PNG
#
# PATCHES ADDED (to address “plots missing / skipped” symptoms):
#   (A) Optional DEBUG_DIAGNOSTICS: prints output paths + CAL class balance + confirms file writes.
#   (B) Hard traceback on failures (instead of silent warnings) to surface root cause.
#   (C) SHAP beeswarm robustification: optional subsample of TRAIN to avoid memory/time failures.
#   (D) Optional SAFE_SINGLE_THREAD: mitigates fork/thread/numba/TBB instability during SHAP/plotting.
#   (E) Explicit existence checks after savefig (so “saved but not where expected” is obvious).
# =============================================================================

# =============================================================================
# SECTION 0: IMPORTS + CONFIG
# =============================================================================

from pathlib import Path
import joblib
import warnings
import traceback
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import StackingClassifier
from sklearn.metrics import (
    classification_report, confusion_matrix, ConfusionMatrixDisplay,
    roc_curve, auc,
)

import MLTraining  # uses MLTraining.py helpers

# -----------------------------------------------------------------------------
# Palettes
# -----------------------------------------------------------------------------
# Label -> hex colour maps, one per annotation depth. The active one is picked
# with PALETTE_BY_DEPTH.get(name_target_class, {}) in the CONFIG section.

# "Broad" depth: two labels.
PALETTE_BROAD = {"Immature": "#0079ea", "Mature": "#AF3434"}

# "Simplified" depth.
PALETTE_SIMPLIFIED = {
    "HSPC":      "#0079ea",
    "Erythroid": "#c11212",
    "pDC":       "#62E6B8",
    "Monocyte":  "#D27CE3",
    "Myeloid":   "#8D43CD",
    "CD4_T":     "#C99546",
    "CD8_T":     "#6B3317",
    "B":         "#68D827",
    "cDC":       "#16D2E3",
    "Other_T":   "#EDB416",
    "NK":        "#FBEF0D",
}

# "Detailed" depth.
# NOTE(review): "Monocyte" and "CD14 Mono" share "#D27CE3"; presumably
# intentional aliasing of the same population — confirm.
PALETTE_DETAILED = {
    "HSC_MPP":            "#0079ea",
    "LMPP":               "#17BECF",
    "GMP":                "#C5E4FF",
    "Myeloid progenitor": "#AEC7E8",
    "Monocyte":           "#D27CE3",
    "CD14 Mono":          "#D27CE3",
    "CD16 Mono":          "#8D43CD",
    "Erythroblast":       "#F30A1A",
    "ErP":                "#D1235A",
    "MEP":                "#E364B0",
    "CD4 T Naive":        "#C99546",
    "CD4 T Memory":       "#C1AF93",
    "CD8 T Naive":        "#4D382E",
    "CD8 T Memory":       "#6B3317",
    "Other_T":            "#EDB416",
    "Treg":               "#6E6C37",
    "B Naive":            "#1C511D",
    "B Memory":           "#68D827",
    "Pro-B":              "#66BB6A",
    "Pre-B":              "#2DBD67",
    "Immature B":         "#91FF7B",
    "Plasma":             "#9DC012",
    "cDC1":               "#76A7CB",
    "cDC2":               "#16D2E3",
    "pDC":                "#69FFCB",
    "NK CD56 bright":     "#F3AC1F",
    "NK CD56 dim":        "#FBEF0D",
}

# Depth name (valid values of name_target_class) -> palette.
PALETTE_BY_DEPTH = {
    "Broad": PALETTE_BROAD,
    "Simplified": PALETTE_SIMPLIFIED,
    "Detailed": PALETTE_DETAILED,
}

# -----------------------------------------------------------------------------
# OPTIONAL: SHAP dependency
# -----------------------------------------------------------------------------
# SHAP steps are gated on HAS_SHAP. Any import-time failure (not only
# ImportError) disables them instead of aborting the script.
HAS_SHAP = False
try:
    import shap  # noqa: F401
except Exception:
    pass
else:
    HAS_SHAP = True

# -----------------------------------------------------------------------------
# EXPORT SWITCHES
# -----------------------------------------------------------------------------
# Which output trees are written: <root>/Release/... and/or <root>/Dev/...
EXPORT_RELEASE, EXPORT_DEV = True, False

# -----------------------------------------------------------------------------
# CONFIG
# -----------------------------------------------------------------------------
# Annotation depth to train; must be a PALETTE_BY_DEPTH key:
# "Broad" | "Simplified" | "Detailed".
name_target_class = "Detailed"
# NOTE(review): this is an empty *dict*, not an empty set; confirm which
# container the downstream filtering expects before adding entries.
EXCLUDE_CLASSES = dict()

# Unknown depth -> empty palette.
custom_palette = PALETTE_BY_DEPTH[name_target_class] if name_target_class in PALETTE_BY_DEPTH else {}
kf = MLTraining.CV       # cross-validation splitter provided by MLTraining
num_cores = -1           # presumably "use all cores" (joblib convention) — confirm in MLTraining
metrics_log = list()     # accumulator for per-head metric records

# -----------------------------------------------------------------------------
# DIAGNOSTICS / ROBUSTIFICATION SWITCHES (PATCH)
# -----------------------------------------------------------------------------
DEBUG_DIAGNOSTICS = True           # print the [DEBUG] path/switch summary
HARD_TRACEBACKS = True             # if True: prints stack traces when plot/SHAP fails
SHAP_TRAIN_SUBSAMPLE_MAX_N = 5000  # set None to disable subsampling
SAFE_SINGLE_THREAD = False         # set True if you see Numba/TBB fork/thread warnings

# -----------------------------------------------------------------------------
# EMBEDDING CONFIG (for Class_Train_data.png)
# -----------------------------------------------------------------------------
# Embedding source: "adata_obsm" | "adata_obs" | "train_df".
EMBEDDING_SOURCE = "adata_obsm"
# Presumably the key/column names used by each source respectively — confirm downstream.
EMBEDDING_OBSM_KEY = "X_umap"
EMBEDDING_OBS_X, EMBEDDING_OBS_Y = "UMAP_1", "UMAP_2"
EMBEDDING_DF_X, EMBEDDING_DF_Y = "UMAP_1", "UMAP_2"

# -----------------------------------------------------------------------------
# ROOTS
# -----------------------------------------------------------------------------
# Output layout under `models_output` (defined elsewhere in the notebook):
#   Dev/<depth>/{Models,Reports,Figures}/<depth>/...  (created when EXPORT_DEV)
#   Release/<depth>/{Models,Reports/*,Figures/*}      (created when EXPORT_RELEASE)
# NOTE(review): DEV paths repeat the depth name twice while RELEASE paths do not;
# confirm this nesting is intentional.
Luecken_root = Path(models_output)

dev_root = Luecken_root / "Dev"
models_root, reports_root, fig_root = (
    dev_root / name_target_class / section / name_target_class
    for section in ("Models", "Reports", "Figures")
)

heads_dir = models_root / "heads"
metrics_dir, probs_dir, dev_importances = (
    reports_root / leaf for leaf in ("metrics", "probabilities", "Importances")
)
fig_percls = fig_root / "per_class"

release_root = Luecken_root / "Release"
release_models, release_reports, release_figs = (
    release_root / name_target_class / section
    for section in ("Models", "Reports", "Figures")
)
release_metrics, release_probs, release_imps = (
    release_reports / leaf for leaf in ("Metrics", "Probabilities", "Importances")
)
release_single = release_figs / "Single_classes"

# Only create the trees that will actually be written to.
if EXPORT_DEV:
    for p in (models_root, heads_dir, reports_root, metrics_dir, probs_dir, fig_root, fig_percls, dev_importances):
        p.mkdir(parents=True, exist_ok=True)

if EXPORT_RELEASE:
    for p in (release_models, release_reports, release_metrics, release_probs, release_imps, release_figs, release_single):
        p.mkdir(parents=True, exist_ok=True)
    print(f"[INFO] RELEASE Root:    {release_root}")
    print(f"[INFO] RELEASE Models:  {release_models}")
    print(f"[INFO] RELEASE Reports: {release_reports}")
    print(f"[INFO] RELEASE Figures: {release_figs}")

if DEBUG_DIAGNOSTICS:
    print(f"[DEBUG] HAS_SHAP={HAS_SHAP} EXPORT_RELEASE={EXPORT_RELEASE} EXPORT_DEV={EXPORT_DEV}")
    print(f"[DEBUG] release_single={release_single}")
    print(f"[DEBUG] release_imps={release_imps}")
    print(f"[DEBUG] SAFE_SINGLE_THREAD={SAFE_SINGLE_THREAD} SHAP_SUBSAMPLE_MAX_N={SHAP_TRAIN_SUBSAMPLE_MAX_N}")

# =============================================================================
# SECTION 1: ATTACH CELL-TYPE LABELS
# =============================================================================
print("\n[STEP 1] Attaching cell-type labels from AnnData.obs...")

# obs column holding the consensus label for the chosen depth,
# e.g. "Consensus_annotation_detailed_final" for name_target_class == "Detailed".
consensus_field = f"Consensus_annotation_{name_target_class.lower()}_final"

Luecken_data_Train, Luecken_data_Test, Luecken_data_Cal = [
    MLTraining.attach_celltype(split_df, split_dataset, consensus_field)
    for split_df, split_dataset in (
        (Luecken_data_Train, Luecken_dataset_Train),
        (Luecken_data_Test, Luecken_dataset_Test),
        (Luecken_data_Cal, Luecken_dataset_Cal),
    )
]

print(f"  ✓ Attached '{consensus_field}' to Train/Test/Cal splits")

# =============================================================================
# SECTION 2: ALIGN DATA COLUMNS TO REFERENCE PANEL
# =============================================================================
print("\n[STEP 2] Aligning data columns to reference panel (exact names preserved)...")

# Reference feature panel (exact names) and a normalized-name -> panel-name lookup.
panel = pd.Index([str(name) for name in TotalSeqD_Heme_Oncology_CAT399906])
panel_keys = MLTraining.norm_feats(panel)
norm_to_panel = {key: name for key, name in zip(panel_keys, panel)}

# Two panel entries sharing a normalized key would make the lookup ambiguous.
if len(norm_to_panel) != len(panel):
    raise ValueError("Panel contains names that collide after normalization. Adjust MLTraining.norm_feats rules.")

def rename_data_to_panel(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df* with feature columns renamed to their exact panel names.

    Feature names are normalized with ``MLTraining.norm_feats`` and looked up in
    the module-level ``norm_to_panel`` map. The metadata columns ``cell_barcode``
    and ``Celltype`` are never renamed, and unmatched columns keep their name.
    When several columns map to the same panel name, the first (in column order)
    is kept and the rest are dropped with a warning.
    """
    out = df.copy()
    meta_cols = [col for col in ("cell_barcode", "Celltype") if col in out.columns]
    feature_cols = pd.Index([col for col in out.columns if col not in meta_cols])

    panel_names = [norm_to_panel.get(key) for key in MLTraining.norm_feats(feature_cols)]
    candidates = {src: dst for src, dst in zip(feature_cols, panel_names) if dst is not None}

    # First column claiming a panel name wins; later claimants are dropped.
    claimed = set()
    accepted, duplicates = {}, []
    for src, dst in candidates.items():
        if dst in claimed:
            duplicates.append(src)
            continue
        claimed.add(dst)
        accepted[src] = dst

    if duplicates:
        print(f"  [WARN] Dropping {len(duplicates)} duplicated-mapped columns (sample: {duplicates[:5]})")
        out = out.drop(columns=duplicates, errors="ignore")

    out = out.rename(columns=accepted)
    print(f"  ✓ Matched {len(accepted)}/{len(feature_cols)} data columns to panel")
    return out

def panel_intersection(df: pd.DataFrame) -> pd.DataFrame:
    """Keep only features that are in the module-level ``panel``, then metadata.

    Shared feature names come from ``panel.intersection(..., sort=False)``, so
    they follow ``panel`` rather than ``df``. Any ``cell_barcode`` / ``Celltype``
    columns are appended after them.

    Raises:
        ValueError: if no feature column of ``df`` is in the panel.
    """
    meta_cols = [col for col in ("cell_barcode", "Celltype") if col in df.columns]
    features = pd.Index([col for col in df.columns if col not in meta_cols])
    shared = panel.intersection(features, sort=False)
    if shared.empty:
        raise ValueError("Panel/Data intersection is empty after renaming. Check mapping rules.")
    return df.reindex(columns=[*shared, *meta_cols])

# Rename each split onto exact panel names, then restrict it to panel features.
Luecken_data_Train = panel_intersection(rename_data_to_panel(Luecken_data_Train))
Luecken_data_Test = panel_intersection(rename_data_to_panel(Luecken_data_Test))
Luecken_data_Cal = panel_intersection(rename_data_to_panel(Luecken_data_Cal))
print("  ✓ Data columns now aligned to panel (panel order preserved)")

# =============================================================================
# SECTION 3: PREPARE FEATURES & LABELS (WITH CAL/TEST ROW FILTERING)
# =============================================================================
print("\n[STEP 3] Extracting features and labels...")

# Calibration labels as a one-column frame.
Luecken_data_Cal_lbl = Luecken_data_Cal[["Celltype"]].copy()

# Metadata columns each split actually carries; everything else is a feature.
drop_cols_train, drop_cols_test, drop_cols_cal = (
    [col for col in ["cell_barcode", "Celltype"] if col in split.columns]
    for split in (Luecken_data_Train, Luecken_data_Test, Luecken_data_Cal)
)

Luecken_data_Train_Sub = Luecken_data_Train.drop(columns=drop_cols_train, errors="ignore")
Luecken_data_Test_Sub = Luecken_data_Test.drop(columns=drop_cols_test, errors="ignore")
Luecken_data_Cal_Sub = Luecken_data_Cal.drop(columns=drop_cols_cal, errors="ignore")

# Every split must expose exactly the TRAIN feature columns, in the same order.
cols_train = list(Luecken_data_Train_Sub.columns)
if any(list(sub.columns) != cols_train for sub in (Luecken_data_Test_Sub, Luecken_data_Cal_Sub)):
    raise ValueError("Train/Cal/Test feature columns differ after panel intersection!")

for split_sub, split_tag in (
    (Luecken_data_Train_Sub, "TRAIN"),
    (Luecken_data_Test_Sub, "TEST"),
    (Luecken_data_Cal_Sub, "CAL"),
):
    MLTraining.check_finite(split_sub, split_tag)

print(f"  ✓ Using {len(cols_train)} panel-intersected features (exact panel names)")
print(f"    Sample: {cols_train[:5]}...")

# Classes come from TRAIN labels (sorted), minus anything in EXCLUDE_CLASSES.
all_classes = sorted(pd.Series(Luecken_data_Train["Celltype"]).dropna().unique())
class_names = [name for name in all_classes if str(name) not in EXCLUDE_CLASSES]

excluded_present = sorted(set(all_classes).intersection(EXCLUDE_CLASSES))
if excluded_present:
    print(f"  [INFO] Excluding {len(excluded_present)} classes: {excluded_present}")

# Integer index of each kept class; also the column order of the P_* matrices.
K = len(class_names)
class_to_idx = dict(zip(class_names, range(K)))
print(f"  ✓ Found {K} classes after exclusions")

# ---- critical: filter CAL/TEST rows to those classes ----
# Rows whose label is not one of the K kept classes (as strings) are removed
# from CAL and TEST, so every remaining row maps to a trained head.
keep_set = {str(name) for name in class_names}

cal_keep_mask = Luecken_data_Cal_lbl["Celltype"].astype(str).isin(keep_set)
test_keep_mask = Luecken_data_Test["Celltype"].astype(str).isin(keep_set)

# Number of rows failing each mask.
n_cal_drop = int(cal_keep_mask.size - cal_keep_mask.sum())
n_test_drop = int(test_keep_mask.size - test_keep_mask.sum())

if n_cal_drop:
    dropped = sorted(Luecken_data_Cal_lbl.loc[~cal_keep_mask, "Celltype"].astype(str).unique().tolist())
    print(f"  [INFO] Dropping {n_cal_drop} CAL rows with excluded/unknown labels: {dropped}")

if n_test_drop:
    dropped = sorted(Luecken_data_Test.loc[~test_keep_mask, "Celltype"].astype(str).unique().tolist())
    print(f"  [INFO] Dropping {n_test_drop} TEST rows with excluded/unknown labels: {dropped}")

# filtered label frames
Luecken_data_Cal_lbl_f = Luecken_data_Cal_lbl.loc[cal_keep_mask].copy()
Luecken_data_Test_lbl_f = Luecken_data_Test.loc[test_keep_mask, ["Celltype"]].copy()

# filtered feature frames, row-aligned to the label frames via their index
X_cal_all_df = Luecken_data_Cal_Sub.loc[Luecken_data_Cal_lbl_f.index].copy()
X_te_all_df = Luecken_data_Test_Sub.loc[Luecken_data_Test_lbl_f.index].copy()
test_index = X_te_all_df.index

# Encode filtered labels as class indices; any leftover NaN is a hard error.
s_cal = Luecken_data_Cal_lbl_f["Celltype"].map(class_to_idx)
unmapped = s_cal.isna()
if unmapped.any():
    missing = Luecken_data_Cal_lbl_f.loc[unmapped, "Celltype"].unique()
    raise ValueError(f"Still-unmapped labels in CAL after filtering: {missing}")
y_cal_multiclass = s_cal.to_numpy(dtype=np.int64)

s_te = Luecken_data_Test_lbl_f["Celltype"].map(class_to_idx)
unmapped = s_te.isna()
if unmapped.any():
    missing = Luecken_data_Test_lbl_f.loc[unmapped, "Celltype"].unique()
    raise ValueError(f"Still-unmapped labels in TEST after filtering: {missing}")
y_test_multiclass = s_te.to_numpy(dtype=np.int64)

# Per-class OvR probability matrices (rows = filtered CAL/TEST cells, cols = K).
n_cal_rows = X_cal_all_df.shape[0]
n_te_rows = X_te_all_df.shape[0]
P_cal_raw = np.zeros((n_cal_rows, K), dtype=float)
P_cal_platt = np.zeros((n_cal_rows, K), dtype=float)
P_te_raw = np.zeros((n_te_rows, K), dtype=float)
P_te_platt = np.zeros((n_te_rows, K), dtype=float)

# Collectors filled by the training loop below.
heads_mem = {}
xgb_shap_rows, lr_contrib_rows, platt_metrics_rows = [], [], []

# =============================================================================
# SECTION 4: TRAIN OvR BINARY HEADS (+ Platt on CAL)
# =============================================================================
print(f"\n[STEP 4] Training {K} binary OvR classifiers...\n")

TOP_N = 10  # features shown in SHAP beeswarms and kept per class in the SHAP export
base_order = ["NB", "XGB", "KNN", "MLP"]  # base-learner names, same order as the stacking estimators

for celltype in class_names:
    k = class_to_idx[celltype]
    cls_safe = MLTraining.safe_name(celltype)
    print(f"▸ Processing {cls_safe} (class {k+1}/{K})")

    # 4.1 Load TRAIN barcodes for this class (CSV with "Positive"/"Negative" columns)
    train_barcodes_df = pd.read_csv(
        f"{train_barcodes_path}/Luecken/Consensus_annotation_{name_target_class.lower()}_final/Barcodes_training_class_{cls_safe}.csv",
        index_col=0
    )
    train_positive_barcodes = train_barcodes_df["Positive"].dropna().values
    train_negative_barcodes = train_barcodes_df["Negative"].dropna().values
    all_train_barcodes = np.concatenate([train_positive_barcodes, train_negative_barcodes])

    # Keep listed barcodes that exist in TRAIN; label 1 = listed as Positive.
    train_mask = Luecken_data_Train_Sub.index.isin(all_train_barcodes)
    X_tr_df = Luecken_data_Train_Sub.loc[train_mask]
    found_train_barcodes = X_tr_df.index.values
    y_tr = np.isin(found_train_barcodes, train_positive_barcodes).astype(int)

    if X_tr_df.empty or np.unique(y_tr).size < 2:
        # A skipped class leaves its column of every P_* matrix at zero.
        print(f"  [SKIP] Empty or single-class train (pos={y_tr.sum()}, neg={len(y_tr)-y_tr.sum()})\n")
        continue

    # 4.1b TRAIN embedding (pos vs rest) + legend
    try:
        MLTraining.save_class_train_umap_pngs(
            celltype=str(celltype),
            cls_safe=cls_safe,
            barcodes=found_train_barcodes,
            y_bin=y_tr,
            custom_palette=custom_palette,
            out_dir_dev=fig_percls if EXPORT_DEV else None,
            out_dir_rel=release_single if EXPORT_RELEASE else None,
            adata_train=Luecken_dataset_Train,
            train_df=Luecken_data_Train,
            embedding_source=EMBEDDING_SOURCE,
            obsm_key=EMBEDDING_OBSM_KEY,
            obs_x=EMBEDDING_OBS_X,
            obs_y=EMBEDDING_OBS_Y,
            df_x=EMBEDDING_DF_X,
            df_y=EMBEDDING_DF_Y,
            neg_color="#A3A3A3",
            outline=(5, 0.05),
            debug=False,
        )
    except Exception as e:
        warnings.warn(f"UMAP train plot failed for '{celltype}': {e}")

    # 4.2 Load TEST barcodes for class-specific metrics (optional)
    test_barcodes_df = pd.read_csv(
        f"{test_barcodes_path}/Luecken/Consensus_annotation_{name_target_class.lower()}_final/Barcodes_testing_class_{cls_safe}.csv",
        index_col=0
    )
    test_positive_barcodes = test_barcodes_df["Positive"].dropna().values
    test_negative_barcodes = test_barcodes_df["Negative"].dropna().values
    all_test_barcodes = np.concatenate([test_positive_barcodes, test_negative_barcodes])

    test_mask = Luecken_data_Test_Sub.index.isin(all_test_barcodes)
    X_te_df = Luecken_data_Test_Sub.loc[test_mask]
    found_test_barcodes = X_te_df.index.values
    y_te = np.isin(found_test_barcodes, test_positive_barcodes).astype(int)

    # Full TEST (filtered) for head probabilities / calibration plot eval
    X_te_all_local = X_te_all_df
    y_te_all = (Luecken_data_Test.loc[X_te_all_df.index, "Celltype"].astype(str).values == str(celltype)).astype(int)

    # CAL split (filtered) for Platt fitting
    X_cal_df  = X_cal_all_df
    y_cal_bin = (Luecken_data_Cal.loc[X_cal_all_df.index, "Celltype"].astype(str).values == str(celltype)).astype(int)

    # 4.3 Fit scaler on TRAIN (this head's rows only); transform all splits
    scaler = StandardScaler(with_mean=True, with_std=True).fit(X_tr_df.values)

    def _sc(df: pd.DataFrame) -> pd.DataFrame:
        """Apply this head's scaler, keeping the index and the TRAIN column names."""
        return pd.DataFrame(scaler.transform(df.values), index=df.index, columns=cols_train)

    X_tr_sc_df      = _sc(X_tr_df)
    X_te_sc_df      = _sc(X_te_df)
    X_te_all_sc_df  = _sc(X_te_all_local)
    X_cal_sc_df     = _sc(X_cal_df)

    # 4.4 Train base learners
    NB_model  = MLTraining.train_NB (X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)
    XGB_model = MLTraining.train_XGB(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)
    KNN_model = MLTraining.train_KNN(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)
    MLP_model = MLTraining.train_MLP(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=cls_safe)

    # 4.5 Stacking RAW head (logistic-regression meta-learner on base probabilities)
    stacker_raw = StackingClassifier(
        estimators=[("NB", NB_model), ("XGB", XGB_model), ("KNN", KNN_model), ("MLP", MLP_model)],
        final_estimator=LogisticRegression(max_iter=2000, class_weight="balanced", random_state=42),
        stack_method="predict_proba",
        cv=kf,
        n_jobs=-1,
    ).fit(X_tr_sc_df, y_tr)

    # 4.6 Platt calibration (fit on CAL only); needs both classes present in CAL
    pos_cal   = int(y_cal_bin.sum())
    n_cal_bin = int(len(y_cal_bin))
    has_both  = (0 < pos_cal < n_cal_bin)

    stacker_platt = None
    if has_both:
        stacker_platt = MLTraining.calibrate_prefit(stacker_raw, X_cal_sc_df, y_cal_bin, method="sigmoid")
    else:
        print("    [WARN] Skipped Platt calibration (single-class CAL)")

    # 4.7 Platt evaluation curve on TEST (Ideal -> RAW -> Platt) + metrics row
    try:
        p_test_raw   = stacker_raw.predict_proba(X_te_all_sc_df)[:, 1]
        p_test_platt = stacker_platt.predict_proba(X_te_all_sc_df)[:, 1] if stacker_platt is not None else None

        dev_platt = (fig_percls / f"{cls_safe}_Platt_calibration_evaluation_on_test.png") if EXPORT_DEV else None
        rel_platt = (release_single / f"{cls_safe}_Platt_calibration_evaluation_on_test.png") if EXPORT_RELEASE else None

        ll_raw, br_raw, ll_pl, br_pl, pl_avail = MLTraining.plot_platt_calibration_on_test(
            y_true_bin=y_te_all.astype(int),
            p_raw=p_test_raw,
            p_platt=p_test_platt,
            title=f"{name_target_class} – {celltype}: Platt calibration evaluation on TEST",
            out_png_dev=dev_platt,
            out_png_rel=rel_platt,
            n_bins=15,
        )

        platt_metrics_rows.append({
            "depth": name_target_class,
            "class_name": str(celltype),
            "n_test_samples": int(len(y_te_all)),
            "n_test_positive": int(y_te_all.sum()),
            "logloss_raw": ll_raw,
            "brier_raw": br_raw,
            "logloss_platt": ll_pl,
            "brier_platt": br_pl,
            "platt_available": bool(pl_avail),
        })

    except Exception as e:
        warnings.warn(f"Platt calibration plot failed for class '{celltype}': {e}")

    # 4.8 Save per-class head bundle + keep in-memory for package
    head_bundle = {
        "atlas": "Luecken",
        "depth": name_target_class,
        "label": str(celltype),
        "panel_name": "TotalSeqD_Heme_Oncology_CAT399906",
        "columns": cols_train,
        "scaler": scaler,
        "model_raw": stacker_raw,
        "model_platt": stacker_platt,
    }
    heads_mem[str(celltype)] = head_bundle

    if EXPORT_DEV:
        joblib.dump(head_bundle, heads_dir / f"{cls_safe}.joblib")

    # 4.9 Optional per-head metrics logging (class-specific TEST subset)
    try:
        model_for_eval = stacker_platt if stacker_platt is not None else stacker_raw
        m = MLTraining.evaluate_classifier(model_for_eval, X_te_sc_df, y_te, plot_cm=False)
        m.update(celltype=str(celltype), used_platt=bool(stacker_platt is not None))
        metrics_log.append(m)
    except Exception as e:
        # Still best-effort, but surface the failure instead of silently
        # discarding it (same handling as the other optional steps here).
        warnings.warn(f"Per-head metrics evaluation failed for class '{celltype}': {e}")

    # 4.10 OvR probability matrices (RAW + PLATT) for multiclass downstream
    P_cal_raw[:, k] = stacker_raw.predict_proba(X_cal_sc_df)[:, 1]
    P_te_raw[:,  k] = stacker_raw.predict_proba(X_te_all_sc_df)[:, 1]

    if stacker_platt is not None:
        P_cal_platt[:, k] = stacker_platt.predict_proba(X_cal_sc_df)[:, 1]
        P_te_platt[:,  k] = stacker_platt.predict_proba(X_te_all_sc_df)[:, 1]
    else:
        # No Platt model: the RAW probabilities stand in for the Platt column.
        P_cal_platt[:, k] = P_cal_raw[:, k]
        P_te_platt[:,  k] = P_te_raw[:,  k]

    # 4.11 SHAP: mean_abs + corr on TEST; beeswarm TRAIN only
    if HAS_SHAP:
        shap_fig = None
        try:
            # Kept as the current figure while the SHAP helpers run (as before).
            # It is closed in ``finally`` so one open figure per class no
            # longer accumulates over the loop.
            shap_fig = plt.figure(figsize=(6, 6))
            shap_sum_test = MLTraining.xgb_shap_mean_abs_and_corr(XGB_model, X_te_all_sc_df, class_index=1)
            shap_sum_test["depth"] = name_target_class
            shap_sum_test["class_name"] = str(celltype)
            shap_sum_test["dataset"] = "TEST"
            xgb_shap_rows.extend(shap_sum_test.to_dict(orient="records"))

            # Beeswarm on TRAIN only. Both destinations used identical
            # xgb_shap_values(...) arguments, so the values are computed once
            # and reused (assumes the helper is deterministic).
            beeswarm_paths = []
            if EXPORT_DEV:
                beeswarm_paths.append(fig_percls / f"{cls_safe}_SHAP_beeswarm_TRAIN.png")
            if EXPORT_RELEASE:
                beeswarm_paths.append(release_single / f"{cls_safe}_SHAP_beeswarm_TRAIN.png")

            if beeswarm_paths:
                sv, _ = MLTraining.xgb_shap_values(XGB_model, X_tr_sc_df, class_index=1)
                for outp in beeswarm_paths:
                    MLTraining.plot_xgb_shap_beeswarm(
                        sv, X_tr_sc_df,
                        title=f"{name_target_class} – {celltype}: SHAP importance",
                        max_display=TOP_N,
                        out_path=outp,
                        figsize=(6, 7),
                    )

        except Exception as e:
            warnings.warn(f"SHAP failed for class '{celltype}': {e}")
        finally:
            if shap_fig is not None:
                plt.close(shap_fig)  # no-op if a helper already closed it

    # 4.12 LR meta-learner contributions (unchanged)
    try:
        contrib = _lr_baselearner_contributions(stacker_raw, X_te_all_sc_df, base_order=base_order)
        row = {
            "depth": name_target_class,
            "class_name": str(celltype),
            "dataset": "TEST",
            "n_meta_features": contrib["n_meta_features"],
            "per_estimator_meta_cols": contrib["per_estimator_meta_cols"],
        }
        for b in base_order:
            row[f"{b}_mean_abs_contribution"] = contrib["per_base"].get(b, {}).get("mean_abs_contribution", 0.0)
            row[f"{b}_coef_l1"]               = contrib["per_base"].get(b, {}).get("coef_l1", 0.0)
            row[f"{b}_n_meta_cols"]           = contrib["per_base"].get(b, {}).get("n_cols", 0)
        lr_contrib_rows.append(row)
    except Exception as e:
        warnings.warn(f"LR contribution extraction failed for class '{celltype}': {e}")

    print("")

# =============================================================================
# EXPORT: Per-class LogLoss & Brier (pre vs post Platt) on TEST
# =============================================================================
print("\n[EXPORT] Per-class calibration metrics (RAW vs Platt on TEST)...")

# Pass None for any destination whose export flag is off.
platt_export_kwargs = dict(
    out_dev=metrics_dir if EXPORT_DEV else None,
    out_rel=release_metrics if EXPORT_RELEASE else None,
    filename="Single_classes_metrics_pre_and_post_platt_calibration.csv",
)
_ = MLTraining.export_platt_metrics_csv(platt_metrics_rows, **platt_export_kwargs)

# =============================================================================
# SECTION 5: MULTICLASS TEMPERATURE SCALING (fit on CAL using PLATT matrix)
# =============================================================================
print("\n[STEP 5] Multiclass Temperature Scaling on CAL (using Platt OvR probabilities)...")

def _check_probs(P: np.ndarray, name: str):
    """Validate that every entry of ``P`` is a finite probability in [0, 1].

    Args:
        P: array of probabilities (any shape).
        name: label used in the error message.

    Raises:
        ValueError: if ``P`` holds NaN/Inf, or any value below 0 or above 1.
    """
    if not np.isfinite(P).all():
        raise ValueError(f"{name} contains NaN/Inf")
    out_of_range = np.logical_or(P < 0, P > 1)
    if out_of_range.any():
        raise ValueError(f"{name} contains values outside [0,1]")

_check_probs(P_cal_platt, "P_cal_platt")
_check_probs(P_te_platt,  "P_te_platt")

# Fit temperature scaling on CAL (multiclass labels vs. the per-class Platt
# probability matrix) and apply it to the TEST Platt matrix.
ts_cal = TemperatureScaling()
ts_cal.fit(P_cal_platt, y_cal_multiclass)
P_te_cal = np.asarray(ts_cal.transform(P_te_platt))
if P_te_cal.ndim == 1:
    P_te_cal = P_te_cal.reshape(-1, 1)

# Remember the scaler's output shape *before* any fallback below replaces
# P_te_cal, so the warning reports what TemperatureScaling actually returned.
ts_out_shape = P_te_cal.shape

if ts_out_shape[1] == 1 and K == 2:
    # Single column for a 2-class problem: treat it as P(class 1) and expand.
    P_te_cal = np.hstack([1.0 - P_te_cal, P_te_cal])
elif ts_out_shape[1] != K:
    # Unexpected width: fall back to row-normalised Platt OvR probabilities
    # (all-zero rows stay zero instead of dividing by 0).
    row_sums = P_te_platt.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0.0] = 1.0
    P_te_cal = P_te_platt / row_sums
    print(f"  [WARN] TemperatureScaling returned shape {ts_out_shape}; fell back to sum-normalized OvR probs")

if EXPORT_DEV:
    joblib.dump(ts_cal, models_root / "temp_scaler.joblib")
    pd.Series(class_names, name="class_name").to_csv(models_root / "class_names.csv", index=False)

# =============================================================================
# SECTION 5b: SAVE DEPLOYABLE PACKAGE(S)
# =============================================================================
print("\n[STEP 5b] Saving deployable package(s)...")

# Single bundle: every trained per-class head plus the shared temperature scaler.
package = dict(
    atlas="Luecken",
    depth=name_target_class,
    panel_name="TotalSeqD_Heme_Oncology_CAT399906",
    class_names=class_names,
    heads=heads_mem,
    temp_scaler=ts_cal,
)

# The same object is written to each enabled destination.
if EXPORT_DEV:
    joblib.dump(package, models_root / "package.joblib")
if EXPORT_RELEASE:
    joblib.dump(package, release_models / "Multiclass_models.joblib")

# =============================================================================
# SECTION 5c: EXPORT IMPORTANCES (Top10 per class)
# =============================================================================
print("\n[STEP 5c] Exporting importances (Top 10 per class; SHAP mean_abs + corr + LR)...")

if xgb_shap_rows:
    group_keys = ["depth", "class_name"]
    shap_df = pd.DataFrame(xgb_shap_rows).sort_values(
        [*group_keys, "mean_abs_shap"], ascending=[True, True, False]
    )
    # Keep the TOP_N largest mean |SHAP| features within each (depth, class).
    shap_df = shap_df.groupby(group_keys, as_index=False).head(TOP_N)
    shap_df["rank_within_class"] = (
        shap_df.groupby(group_keys)["mean_abs_shap"].rank(ascending=False, method="first").astype(int)
    )
    shap_df = shap_df[
        [*group_keys, "dataset", "feature", "mean_abs_shap", "corr_feature_value_vs_shap", "rank_within_class"]
    ]
    if EXPORT_DEV:
        shap_df.to_csv(dev_importances / "SHAP_XGB_Feature_importances.csv", index=False)
    if EXPORT_RELEASE:
        shap_df.to_csv(release_imps / "SHAP_XGB_Feature_importances.csv", index=False)
else:
    print("  [INFO] No SHAP rows collected (or SHAP not installed).")

if lr_contrib_rows:
    # One row per class with each base learner's contribution to the LR meta-learner.
    lr_df = pd.DataFrame(lr_contrib_rows)
    if EXPORT_DEV:
        lr_df.to_csv(dev_importances / "LR_MetaLearner_BaseLearner_contributions.csv", index=False)
    if EXPORT_RELEASE:
        lr_df.to_csv(release_imps / "LR_MetaLearner_BaseLearner_contributions.csv", index=False)
else:
    print("  [INFO] No LR contribution rows collected.")

# =============================================================================
# SECTION 6: SAVE PROBABILITIES
# =============================================================================
print("\n[STEP 6] Saving probability outputs...")

if EXPORT_DEV:
    # RAW, Platt and temperature-scaled probabilities side by side per TEST cell,
    # followed by the true label and argmax predictions (index and name).
    probs_raw_df = pd.DataFrame(P_te_raw, index=test_index, columns=[f"raw_{c}" for c in class_names])
    probs_platt_df = pd.DataFrame(P_te_platt, index=test_index, columns=[f"platt_{c}" for c in class_names])
    probs_cal_df = pd.DataFrame(P_te_cal, index=test_index, columns=[f"cal_{c}" for c in class_names])

    pred_raw_idx = P_te_raw.argmax(axis=1)
    pred_cal_idx = P_te_cal.argmax(axis=1)
    probs_dev = pd.concat([probs_raw_df, probs_platt_df, probs_cal_df], axis=1).assign(
        true_label=Luecken_data_Test.loc[test_index, "Celltype"].values,
        pred_raw=pred_raw_idx,
        pred_cal=pred_cal_idx,
        pred_raw_name=[class_names[i] for i in pred_raw_idx],
        pred_cal_name=[class_names[i] for i in pred_cal_idx],
    )
    probs_dev.to_csv(probs_dir / "probabilities_before_after_TEST.csv", index=True)

if EXPORT_RELEASE:
    # Release file: temperature-scaled probabilities only, plus prediction summary.
    probs_cal_df = pd.DataFrame(P_te_cal, index=test_index, columns=[f"cal_{c}" for c in class_names])
    pred_cal_idx = P_te_cal.argmax(axis=1)
    probs_release = probs_cal_df.assign(
        true_label=Luecken_data_Test.loc[test_index, "Celltype"].values,
        pred_cal=pred_cal_idx,
        pred_cal_name=[class_names[i] for i in pred_cal_idx],
        max_cal_prob=probs_cal_df.max(axis=1).values,
    )
    probs_release.to_csv(release_probs / "Multiclass_models_probabilities_on_test.csv", index=True)

# =============================================================================
# SECTION 7: MULTICLASS EVALUATION (TEST) — using CAL probabilities
# =============================================================================
print("\n[STEP 7] Multiclass evaluation (TEST; using CAL probs)...\n")

y_pred_cal = P_te_cal.argmax(axis=1)

# Pin the label set to every trained class index so ``target_names`` always
# lines up with the report rows. Without ``labels=``, sklearn infers labels
# from y_true ∪ y_pred and raises ValueError whenever some class appears in
# neither (e.g. a class with no TEST cells that is also never predicted).
# This matches the ``labels=`` already used for the confusion matrix.
class_label_idx = list(range(K))

report_txt = classification_report(
    y_test_multiclass, y_pred_cal, labels=class_label_idx, target_names=class_names, digits=3
)
print("Multiclass Classification Report (TEST):")
print(report_txt)

cm_mc = confusion_matrix(y_test_multiclass, y_pred_cal, labels=class_label_idx)
print("\nConfusion Matrix (rows=true, cols=pred):")
print(cm_mc)

report_df = pd.DataFrame(
    classification_report(
        y_test_multiclass, y_pred_cal, labels=class_label_idx, target_names=class_names, output_dict=True
    )
).T

cm_mc_df = pd.DataFrame(cm_mc, index=pd.Index(class_names, name="true"), columns=pd.Index(class_names, name="pred"))

if EXPORT_DEV:
    report_df.to_csv(metrics_dir / "multiclass_classification_report_TEST.csv")
    cm_mc_df.to_csv(metrics_dir / "multiclass_confusion_matrix_TEST.csv")

if EXPORT_RELEASE:
    report_df.to_csv(release_metrics / "Multiclass_models_metrics_on_test.csv")
    cm_mc_df.to_csv(release_metrics / "Multiclass_models_confusion_matrix_on_test.csv")

# =============================================================================
# SECTION 8: FIGURES (MULTICLASS CM + PER-CLASS CONF & ROC)
# =============================================================================
print("\n[STEP 8] Saving plots...")

def _save_multiclass_cm_png(out_path: Path):
    """Save the module-level multiclass confusion matrix ``cm_mc`` as a 300-dpi PNG.

    Tick labels come from ``class_names`` and the title from ``name_target_class``.
    """
    fig, ax = plt.subplots(figsize=(7, 6))
    try:
        disp = ConfusionMatrixDisplay(confusion_matrix=cm_mc, display_labels=class_names)
        # Draw into our own axes. Without ``ax=``, sklearn opens a new
        # default-sized figure: the requested figsize was ignored and that
        # second figure was never closed.
        disp.plot(ax=ax, values_format="d", cmap="Blues", colorbar=False)
        ax.set_title(f"{name_target_class} – Multiclass Confusion Matrix (on TEST)")
        fig.tight_layout()
        fig.savefig(out_path, dpi=300)
    finally:
        plt.close(fig)

# Multiclass confusion-matrix figure, once per enabled export target.
if EXPORT_DEV:
    _save_multiclass_cm_png(fig_root / "multiclass_confusion_matrix_TEST.png")
if EXPORT_RELEASE:
    _save_multiclass_cm_png(release_figs / "Multiclass_models_confusion_matrix_on_test.png")

# Collector for per-class binary metrics and the RAW (uncalibrated) argmax predictions.
per_class_rows = []
y_pred_raw = np.argmax(P_te_raw, axis=1)

def _metrics_from_cm(cm2x2):
    """Summarise a binary confusion matrix laid out as ``[[TN, FP], [FN, TP]]``.

    Args:
        cm2x2: array whose ``ravel()`` yields TN, FP, FN, TP in that order
            (the layout of ``confusion_matrix(..., labels=[0, 1])``).

    Returns:
        dict with integer TP/FP/TN/FN/support (support = TP + FN) and
        precision/recall/f1. Each ratio is 0.0 when its denominator is 0.
    """
    tn, fp, fn, tp = cm2x2.ravel()
    predicted_pos = tp + fp
    actual_pos = tp + fn

    precision = tp / predicted_pos if predicted_pos > 0 else 0.0
    recall = tp / actual_pos if actual_pos > 0 else 0.0
    pr_sum = precision + recall
    f1 = 2 * precision * recall / pr_sum if pr_sum > 0 else 0.0

    return {
        "TP": int(tp),
        "FP": int(fp),
        "TN": int(tn),
        "FN": int(fn),
        "support": int(actual_pos),
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }

def _save_cm_fig(cm2x2, cls_label, title, out_dev: Path | None, out_rel: Path | None):
    """Render a 2x2 one-vs-rest confusion matrix and save it to each given path.

    Args:
        cm2x2: binary confusion matrix ordered as (Other, ``cls_label``).
        cls_label: display name of the positive class.
        title: figure title.
        out_dev, out_rel: PNG destinations (300 dpi); ``None`` skips that target.
    """
    fig, ax = plt.subplots(figsize=(5.5, 5.0))
    try:
        # Plot into our axes: without ``ax=``, sklearn creates a separate
        # default-size figure, ignoring figsize and leaking that figure.
        ConfusionMatrixDisplay(confusion_matrix=cm2x2, display_labels=["Other", cls_label]).plot(
            ax=ax, values_format="d", cmap="Blues", colorbar=False
        )
        ax.set_title(title)
        fig.tight_layout()
        for out_path in (out_dev, out_rel):
            if out_path is not None:
                fig.savefig(out_path, dpi=300)
    finally:
        plt.close(fig)

def _save_roc(y_true, y_score, title, out_dev: Path | None, out_rel: Path | None):
    """Plot a binary ROC curve, save it to each non-None path, and return its AUC.

    Args:
        y_true: binary ground truth.
        y_score: scores for the positive class.
        title: figure title; ``" AUC=<value>"`` is appended.
        out_dev, out_rel: PNG destinations (300 dpi); ``None`` skips that target.

    Returns:
        The area under the ROC curve (``auc(fpr, tpr)``).
    """
    fpr, tpr, _ = roc_curve(y_true, y_score)
    a = auc(fpr, tpr)
    fig, ax = plt.subplots(figsize=(6.0, 5.5))
    try:
        ax.plot([0, 1], [0, 1], linestyle="--", linewidth=1, color="gray")  # chance diagonal
        ax.plot(fpr, tpr)
        ax.set_xlabel("False Positive Rate")
        ax.set_ylabel("True Positive Rate")
        ax.set_title(f"{title} AUC={a:.3f}")
        fig.tight_layout()
        for out_path in (out_dev, out_rel):
            if out_path is not None:
                fig.savefig(out_path, dpi=300)
    finally:
        # Close even if saving fails, so figures do not pile up across classes.
        plt.close(fig)
    return a

def _per_class_fig_paths(filename: str):
    """Return (dev_path, release_path) for a per-class figure; None where that export is off."""
    dev_path = fig_percls / filename if EXPORT_DEV else None
    rel_path = release_single / filename if EXPORT_RELEASE else None
    return dev_path, rel_path


# For each class: binarise the argmax predictions (class vs. rest) from the RAW
# and temperature-scaled ("CAL") probabilities, save confusion-matrix and ROC
# figures, and collect one metrics row per variant.
for k, cls in enumerate(class_names):
    cls_safe = MLTraining.safe_name(cls)
    y_true_bin = (y_test_multiclass == k).astype(int)

    score_raw, score_cal = P_te_raw[:, k], P_te_cal[:, k]

    y_pred_raw_bin = (y_pred_raw == k).astype(int)
    y_pred_cal_bin = (y_pred_cal == k).astype(int)

    cm_raw = confusion_matrix(y_true_bin, y_pred_raw_bin, labels=[0, 1])
    cm_cal = confusion_matrix(y_true_bin, y_pred_cal_bin, labels=[0, 1])

    dev_out, rel_out = _per_class_fig_paths(f"{cls_safe}_RAW_confusion_matrix_on_test.png")
    _save_cm_fig(cm_raw, cls, f"{name_target_class} – {cls}: Confusion Matrix (RAW; pre-Platt & Temp)", dev_out, rel_out)

    dev_out, rel_out = _per_class_fig_paths(f"{cls_safe}_CAL_confusion_matrix_on_test.png")
    _save_cm_fig(cm_cal, cls, f"{name_target_class} – {cls}: Confusion Matrix (CAL; Platt & Temp)", dev_out, rel_out)

    dev_out, rel_out = _per_class_fig_paths(f"{cls_safe}_RAW_ROC_on_test.png")
    auc_raw = _save_roc(y_true_bin, score_raw, f"{name_target_class} – {cls}: ROC (RAW; pre-Platt & Temp)", dev_out, rel_out)

    dev_out, rel_out = _per_class_fig_paths(f"{cls_safe}_CAL_ROC_on_test.png")
    auc_cal = _save_roc(y_true_bin, score_cal, f"{name_target_class} – {cls}: ROC (CAL; Platt & Temp)", dev_out, rel_out)

    m_raw = _metrics_from_cm(cm_raw)
    m_raw.update(model="RAW", class_name=cls, auc=auc_raw)
    per_class_rows.append(m_raw)

    m_cal = _metrics_from_cm(cm_cal)
    m_cal.update(model="CAL", class_name=cls, auc=auc_cal)
    per_class_rows.append(m_cal)

# =============================================================================
# SECTION 9: SAVE METRICS TABLES
# =============================================================================
print("\n[STEP 9] Saving metrics tables...")

# One row per (class, RAW/CAL) pair, sorted by class then model.
per_class_columns = ["class_name", "model", "TP", "FP", "TN", "FN", "support", "precision", "recall", "f1", "auc"]
per_class_df = pd.DataFrame(per_class_rows).loc[:, per_class_columns].sort_values(["class_name", "model"])

if EXPORT_DEV:
    per_class_df.to_csv(metrics_dir / "per_class_argmax_metrics_TEST_included.csv", index=False)

if EXPORT_RELEASE:
    out_single = release_metrics / "Single_classes_metrics_and_confusion_matrix_on_test.csv"
    per_class_df.to_csv(out_single, index=False)

# Per-head metrics collected in step 4.9 are appended to a running CSV.
if EXPORT_DEV:
    metrics_df = pd.DataFrame.from_records(metrics_log)
    MLTraining.append_metrics_csv(metrics_df, csv_path=dev_root / "stacker_metrics.csv")

print("\n✅ DETAILED PIPELINE COMPLETE. Exports saved according to EXPORT_DEV / EXPORT_RELEASE.\n")


# Additional metrics

In [None]:
# -*- coding: utf-8 -*-
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

# ---------- INPUT ----------
# Reference CSV path. In this cell it is only used to derive the assay root
# (``cm_csv_path.parents[5]``, i.e. the .../TotalSeqD_Heme_Oncology_CAT399906
# directory); the file itself is not read here.
cm_csv_path = Path(
    "/Users/kgurashi/GitHub/2024__EspressoPro_Manuscript/Data/Pre_trained_models/"
    "TotalSeqD_Heme_Oncology_CAT399906/Release/Hao/Reports/Detailed/metrics/"
    "multiclass_confusion_matrix_TEST_included.csv"
)

# Depth whose per-author Release/<Depth> reports are rendered.
# NOTE: this rebinds the notebook-wide ``name_target_class`` used by the training code.
name_target_class = "Simplified"  # Broad | Simplified | Detailed

# Heatmap options passed to render_heatmap.
NORMALIZE  = "row"   # None | "row" | "col"
CBAR_LABEL = "% Agreement"

# Reference atlases (authors) to render, in order.
authors_to_do = ["Hao", "Luecken", "Zhang", "Triana"]

def render_heatmap(csv_path: Path, normalize: str | None = "row", cbar_label: str = "% Agreement"):
    """Render a confusion-matrix CSV as a heatmap and save it as PNG and PDF.

    Args:
        csv_path: CSV whose index (first column) holds the true labels and whose
            columns hold the predicted labels.
        normalize: ``"row"`` converts each row to percentages of its total,
            ``"col"`` does the same per column, and ``None`` plots raw counts.
            Rows/columns summing to 0 stay at 0.
        cbar_label: colour-bar label.

    Output files are
    ``Multiclass_models_confusion_matrix_on_test_with_percentage_agreement``
    ``.png`` (300 dpi) and ``.pdf``. They go in ``csv_path.parents[2] / "Figures"``,
    which is created if missing.

    Raises:
        ValueError: if ``normalize`` is not ``None``, ``"row"`` or ``"col"``.
    """
    cm_df = pd.read_csv(csv_path, index_col=0)
    labels_true = list(cm_df.index)
    labels_pred = list(cm_df.columns)
    cm = cm_df.values.astype(float)

    # ---------- NORMALISE ----------
    if normalize is None:
        data = cm
    else:
        if normalize == "row":
            denom = cm.sum(axis=1, keepdims=True)
        elif normalize == "col":
            denom = cm.sum(axis=0, keepdims=True)
        else:
            raise ValueError("normalize must be None | 'row' | 'col'")
        denom[denom == 0] = 1.0  # empty rows/cols -> 0 % instead of division by zero
        data = (cm / denom) * 100.0

    # ---------- OUTPUT PATHS ----------
    # Figures go two directories above the CSV's folder. For the layout the
    # driver loop builds,
    #   csv_path   = .../<Author>/Release/<Depth>/Reports/metrics/<file.csv>
    #   parents[0] = .../<Author>/Release/<Depth>/Reports/metrics
    #   parents[1] = .../<Author>/Release/<Depth>/Reports
    #   parents[2] = .../<Author>/Release/<Depth>   -> Figures/ is created here
    depth_root = csv_path.parents[2]
    figures_dir = depth_root / "Figures"
    figures_dir.mkdir(parents=True, exist_ok=True)

    out_base = figures_dir / "Multiclass_models_confusion_matrix_on_test_with_percentage_agreement"
    png_path = out_base.with_suffix(".png")
    pdf_path = out_base.with_suffix(".pdf")

    # ---------- PLOT ----------
    n_cols = len(labels_pred)
    n_rows = len(labels_true)

    # Scale the canvas with the number of classes (5 inch minimum).
    fig_w = max(5, 0.6 * n_cols + 2)
    fig_h = max(5, 0.6 * n_rows + 2)
    fig, ax = plt.subplots(figsize=(fig_w, fig_h), constrained_layout=False)
    try:
        im = ax.imshow(data, cmap="magma", aspect="auto", interpolation="nearest")

        ax.set_xticks(np.arange(n_cols))
        ax.set_yticks(np.arange(n_rows))
        ax.set_xticklabels(labels_pred, rotation=45, ha="right", fontsize=20)
        ax.set_yticklabels(labels_true, fontsize=20)
        ax.tick_params(axis="both", which="major", labelsize=20)

        ax.set_xlim(-0.5, n_cols - 0.5)
        ax.set_ylim(n_rows - 0.5, -0.5)  # first true label at the top

        ax.set_xlabel("")
        ax.set_ylabel("")

        # Cell separators: minor ticks sit on cell edges and carry the grid.
        ax.set_xticks(np.arange(-0.5, n_cols, 1.0), minor=True)
        ax.set_yticks(np.arange(-0.5, n_rows, 1.0), minor=True)
        ax.grid(which="minor", color="#373737", linewidth=3)
        ax.tick_params(which="minor", bottom=False, left=False)

        # Thin colour bar attached to the right edge of the heatmap axes.
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="2.5%", pad=0.02)
        cb = fig.colorbar(im, cax=cax)
        cb.set_label(cbar_label, fontsize=20)
        cb.ax.tick_params(labelsize=20, length=3, width=0.6)
        for spine in cax.spines.values():
            spine.set_linewidth(0.6)

        for spine in ax.spines.values():
            spine.set_edgecolor("#000000")
            spine.set_linewidth(3)

        fig.tight_layout()
        fig.savefig(png_path, dpi=300)
        fig.savefig(pdf_path)
    finally:
        plt.close(fig)

    print(f"[SAVE] {png_path}")
    print(f"[SAVE] {pdf_path}")

# Infer assay root
# cm_csv_path: .../<Assay>/Release/Hao/Reports/Detailed/metrics/<file.csv>
# parents[0..5] = .../metrics, .../Detailed, .../Reports, .../Hao, .../Release, .../<Assay>,
# so assay_root is the .../<Assay> directory itself (NOT .../<Assay>/Release).
assay_root = cm_csv_path.parents[5]

# Per-author reports are looked up under
# <Assay>/<Author>/Release/<Depth>/Reports/metrics/Multiclass_models_confusion_matrix_on_test.csv;
# authors without that file are reported and skipped.
for author in authors_to_do:
    csv = (
        assay_root
        / author
        / "Release"
        / name_target_class
        / "Reports"
        / "metrics"
        / "Multiclass_models_confusion_matrix_on_test.csv"
    )
    if not csv.exists():
        print(f"[SKIP] Missing: {csv}")
        continue
    render_heatmap(csv, normalize=NORMALIZE, cbar_label=CBAR_LABEL)


In [None]:
# -*- coding: utf-8 -*-
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

# ---------- INPUT ----------
# Annotation depth whose per-author metrics are plotted.
# One of: Broad | Simplified | Detailed
name_target_class = "Detailed"

# Datasets (authors) to render, processed in this order.
authors_to_do = ["Hao", "Luecken", "Zhang", "Triana"]

# Heatmap options: normalization mode (None | "row" | "col") and the
# colour-bar label passed to render_heatmap.
NORMALIZE = "row"
CBAR_LABEL = "% Agreement"

# Reference confusion-matrix path. In this cell it is only used to derive the
# assay directory (cm_csv_path.parents[5]); this particular file is not read.
cm_csv_path = Path(
    "/Users/kgurashi/GitHub/2024__EspressoPro_Manuscript/Data/Pre_trained_models/"
    "TotalSeqD_Heme_Oncology_CAT399906/Release/Hao/Reports/Detailed/metrics/"
    "multiclass_confusion_matrix_TEST_included.csv"
)

def _normalize_confusion_matrix(cm, normalize):
    """Return ``cm`` unchanged, or converted to row/column percentages.

    ``normalize`` must already be validated. ``None`` returns ``cm`` as-is.
    ``"row"`` divides each row by its sum and ``"col"`` divides each column
    by its sum; the result is then multiplied by 100. Rows/columns summing to
    zero are divided by 1 instead, so they stay at 0 rather than becoming NaN.
    """
    if normalize is None:
        return cm
    axis = 1 if normalize == "row" else 0
    denom = cm.sum(axis=axis, keepdims=True)
    denom[denom == 0] = 1.0
    return (cm / denom) * 100.0


def render_heatmap(
    csv_path: Path,
    normalize: "str | None" = "row",
    cbar_label: str = "% Agreement",
    *,
    vmin: "float | None" = None,
    vmax: "float | None" = None,
):
    """Render a confusion-matrix CSV as a heatmap and save it as PNG and PDF.

    The CSV is read with ``index_col=0``. Its index gives the y-axis (true)
    labels and its columns give the x-axis (predicted) labels.

    Parameters
    ----------
    csv_path : Path or str
        Confusion-matrix CSV. Figures are written to
        ``csv_path.parents[2] / "Figures"`` (created if missing). For a path
        of the form ``.../<Depth>/Reports/metrics/<file.csv>`` that directory
        is ``.../<Depth>/Figures``.
    normalize : {"row", "col", None}
        ``"row"``/``"col"`` plot each cell as a percentage of its row/column
        total; ``None`` plots the raw values.
    cbar_label : str
        Label of the colour bar.
    vmin, vmax : float or None, keyword-only
        Colour-scale limits forwarded to ``imshow``. With the default ``None``,
        matplotlib autoscales to the data range. Pass ``vmin=0, vmax=100`` to
        give percentage heatmaps of different datasets a common scale.

    Raises
    ------
    ValueError
        If ``normalize`` is not ``None``, ``"row"`` or ``"col"``. This is
        checked before the CSV is read.
    """
    # Validate up front so a typo fails fast, before any file I/O.
    if normalize not in (None, "row", "col"):
        raise ValueError("normalize must be None | 'row' | 'col'")

    csv_path = Path(csv_path)
    cm_df = pd.read_csv(csv_path, index_col=0)
    labels_true = list(cm_df.index)
    labels_pred = list(cm_df.columns)
    data = _normalize_confusion_matrix(cm_df.values.astype(float), normalize)

    # ---------- OUTPUT PATHS ----------
    # parents[0] is the CSV's folder (e.g. metrics), parents[1] its parent
    # (e.g. Reports), so Figures/ is created next to that parent.
    figures_dir = csv_path.parents[2] / "Figures"
    figures_dir.mkdir(parents=True, exist_ok=True)

    out_base = figures_dir / "Multiclass_models_confusion_matrix_on_test_with_percentage_agreement"
    png_path = out_base.with_suffix(".png")
    pdf_path = out_base.with_suffix(".pdf")

    # ---------- PLOT ----------
    n_cols = len(labels_pred)
    n_rows = len(labels_true)

    # Scale the canvas with the number of classes, with a 5-inch minimum.
    fig_w = max(5, 0.6 * n_cols + 2)
    fig_h = max(5, 0.6 * n_rows + 2)
    fig, ax = plt.subplots(figsize=(fig_w, fig_h), constrained_layout=False)
    try:
        im = ax.imshow(
            data,
            cmap="magma",
            aspect="auto",
            interpolation="nearest",
            vmin=vmin,
            vmax=vmax,
        )

        ax.set_xticks(np.arange(n_cols))
        ax.set_yticks(np.arange(n_rows))
        ax.set_xticklabels(labels_pred, rotation=45, ha="right", fontsize=20)
        ax.set_yticklabels(labels_true, fontsize=20)
        ax.tick_params(axis="both", which="major", labelsize=20)

        # Inverted y-limits keep row 0 at the top, matching the CSV order.
        ax.set_xlim(-0.5, n_cols - 0.5)
        ax.set_ylim(n_rows - 0.5, -0.5)

        ax.set_xlabel("")
        ax.set_ylabel("")

        # Minor ticks on cell boundaries draw the grid lines between cells.
        ax.set_xticks(np.arange(-0.5, n_cols, 1.0), minor=True)
        ax.set_yticks(np.arange(-0.5, n_rows, 1.0), minor=True)
        ax.grid(which="minor", color="#373737", linewidth=3)
        ax.tick_params(which="minor", bottom=False, left=False)

        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="2.5%", pad=0.02)
        cb = fig.colorbar(im, cax=cax)
        cb.set_label(cbar_label, fontsize=20)
        cb.ax.tick_params(labelsize=20, length=3, width=0.6)
        for spine in cax.spines.values():
            spine.set_linewidth(0.6)

        for spine in ax.spines.values():
            spine.set_edgecolor("#000000")
            spine.set_linewidth(3)

        fig.tight_layout()
        fig.savefig(png_path, dpi=300)
        fig.savefig(pdf_path)
    finally:
        # Always release the figure, so a failure while rendering one author
        # inside a loop does not leave open figures behind.
        plt.close(fig)

    print(f"[SAVE] {png_path}")
    print(f"[SAVE] {pdf_path}")

# Derive the assay directory from cm_csv_path. For the configured path
# .../TotalSeqD_Heme_Oncology_CAT399906/Release/Hao/Reports/Detailed/metrics/<file.csv>,
# parents[5] is .../TotalSeqD_Heme_Oncology_CAT399906 itself, i.e. the
# <Assay> directory (parents[4] would be .../<Assay>/Release).
assay_root = cm_csv_path.parents[5]

# Render one heatmap per author. Note the per-author CSVs are looked up under
#   <assay_root>/<author>/Release/<name_target_class>/Reports/metrics/
# which is a different layout from cm_csv_path; authors whose CSV does not
# exist are reported and skipped.
for author in authors_to_do:
    # Named metrics_csv (not csv) so the stdlib `csv` module name is not
    # shadowed in the notebook namespace.
    metrics_csv = (
        assay_root
        / author
        / "Release"
        / name_target_class
        / "Reports"
        / "metrics"
        / "Multiclass_models_confusion_matrix_on_test.csv"
    )
    if not metrics_csv.exists():
        print(f"[SKIP] Missing: {metrics_csv}")
        continue
    render_heatmap(metrics_csv, normalize=NORMALIZE, cbar_label=CBAR_LABEL)
