# Set-up

In [None]:
import os

# NOTE: '__file__' is a quoted literal, not the module attribute (which does not
# exist inside a notebook kernel). realpath() resolves it against the current
# working directory, so script_dir is normally that directory with symlinks
# resolved, and parent_dir is one level above it.
script_dir = os.path.split(os.path.realpath('__file__'))[0]
parent_dir = os.path.split(script_dir)[0]

## Importing modules

In [None]:
import scanpy as sc
import pandas as pd
import numpy as np
import os
import anndata
import random
import matplotlib.pyplot as plt
import glob
import harmonypy as hm
import seaborn as sns
import espressopro as ep

##

from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import StackingClassifier
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
from sklearn import preprocessing
import pickle

# 5-fold cross-validation splitter; rows are shuffled before splitting and the
# shuffle is seeded (random_state=42) so the folds are reproducible.
# NOTE: `kf` is rebound to MLTraining.CV in the Hao "Broad annotation" cell.
kf = KFold(n_splits=5, shuffle=True, random_state=42)

##

import multiprocessing

# Number of worker cores: leave two cores free, but never go below one.
# On machines with one or two cores `cpu_count() - 2` would be zero or negative,
# which is not a usable worker count.
num_cores = max(1, multiprocessing.cpu_count() - 2)

print(f"Total CPU cores to be used: {num_cores}")

In [None]:
import warnings
# Silence every Python warning for the rest of the session.
# NOTE(review): this also hides deprecation warnings raised by the libraries
# used below; consider a narrower filter while debugging.
warnings.filterwarnings('ignore')

Loading custom scripts

In [None]:
import sys
# Make the project-local SingleCellProcessing scripts importable (path is
# relative to parent_dir, so the notebook must be started from its own folder).
sys.path.append(parent_dir + '/Scripts/SingleCellProcessing')

import SCUtils

import sys
# Same for the ModelTraining scripts; MLTraining is used by the training cells.
sys.path.append(parent_dir + '/Scripts/ModelTraining')

import MLTraining

In [None]:
def assign_labels(dataset, reduction, n_neighbors, label_input, label_output,
                  frequency_threshold, resolution=10, cluster_key='clusters'):
    """Smooth per-cell labels by majority vote inside Leiden clusters.

    A neighbourhood graph is built on `dataset.obsm[reduction]`, the cells are
    clustered with Leiden, and every cluster whose most frequent `label_input`
    value covers more than `frequency_threshold` of its cells has that value
    written to `label_output` for all of its cells. Cells in other clusters keep
    their original label.

    Parameters
    ----------
    dataset : AnnData-like object accepted by `sc.pp.neighbors` / `sc.tl.leiden`.
    reduction : str
        Representation passed to `sc.pp.neighbors(use_rep=...)`.
    n_neighbors : int
        Neighbourhood size passed to `sc.pp.neighbors`.
    label_input : str
        `dataset.obs` column holding the labels to vote on.
    label_output : str
        `dataset.obs` column that receives the smoothed labels.
    frequency_threshold : float
        Minimum fraction (exclusive) the majority label must reach in a cluster.
    resolution : float, default 10
        Leiden resolution (previously hard-coded to 10).
    cluster_key : str, default 'clusters'
        `dataset.obs` column that receives the Leiden cluster ids.

    Returns
    -------
    The same `dataset` object, modified in place.
    """
    # Compute the neighbourhood graph and cluster on it
    sc.pp.neighbors(dataset, use_rep=reduction, n_neighbors=n_neighbors)
    sc.tl.leiden(dataset, key_added=cluster_key, resolution=resolution)

    # Start from a copy of the existing labels
    dataset.obs[label_output] = dataset.obs[label_input].copy()

    for cluster in dataset.obs[cluster_key].unique():
        # Build the membership mask once and reuse it for reading and writing
        in_cluster = dataset.obs[cluster_key] == cluster
        cluster_labels = dataset.obs.loc[in_cluster, label_input]

        # mode() is empty when every label in the cluster is missing; there is
        # nothing to vote on in that case, so the cluster is left untouched.
        modes = cluster_labels.mode()
        if modes.empty:
            continue

        most_frequent_label = modes.iloc[0]
        frequency = (cluster_labels == most_frequent_label).mean()

        if frequency > frequency_threshold:
            dataset.obs.loc[in_cluster, label_output] = most_frequent_label

    return dataset

In [None]:
pip list

PYTHONHASHSEED was set as an environment variable to 0 as follows:
    
    conda env config vars set PYTHONHASHSEED=0

In [None]:
# Pin the global sources of randomness used by the notebook.
# PYTHONHASHSEED is exported here for any child process; the running
# interpreter reads it at start-up only (hence the conda setting noted above).
os.environ.update({'PYTHONHASHSEED': '0'})
np.random.seed(42)
random.seed(42)

In [None]:
def ensure_pythonhashseed(seed=0):
    """Re-launch the interpreter unless PYTHONHASHSEED already equals `seed`.

    When the environment variable is missing or different, it is set to
    `str(seed)` and the current process is replaced (os.execl) by a fresh
    interpreter started with the same sys.argv. When it already matches, the
    function returns without doing anything.
    """
    seed = str(seed)
    if os.environ.get("PYTHONHASHSEED") == seed:
        return  # already pinned to the requested value

    print(f'Setting PYTHONHASHSEED="{seed}"')
    os.environ["PYTHONHASHSEED"] = seed
    # Replace the current process so the new interpreter starts with the seed.
    os.execl(sys.executable, sys.executable, *sys.argv)

In [None]:
import random

# Draw 128 random bits and print them as a quick reproducibility fingerprint
# (with `random` seeded earlier, the value should be identical across runs).
# The value is not stored under the name `hash`: that would shadow the built-in
# hash() function for every later cell.
seed_check_bits = random.getrandbits(128)

print("hash value: %032x" % seed_check_bits)

## Defining data path

In [None]:
# Project folders, both rooted at the notebook's parent directory
data_path, figures_path = (parent_dir + "/Data", parent_dir + "/Figures")

# Make sure the folder for the trained models exists (no error if it already does)
os.makedirs(data_path + "/Pre_trained_models", exist_ok=True)

# Loading datasets

In [None]:
# Folders holding the train / test barcode lists
train_barcodes_path, test_barcodes_path = (
    data_path + "/Training_barcodes",
    data_path + "/Testing_barcodes",
)
# Only the last expression of a cell is displayed
test_barcodes_path

## Loading Hao Y. et al. (2021) dataset

In [None]:
# Read the Train / Test / Cal splits of the Hao reference (in that order)
Hao_dataset_Train, Hao_dataset_Test, Hao_dataset_Cal = (
    sc.read_h5ad(data_path + "/References/Hao" + h5ad_name)
    for h5ad_name in (
        "/228AB_healthy_donors_PBMNCs_annotated_Train.h5ad",
        "/228AB_healthy_donors_PBMNCs_annotated_Test.h5ad",
        "/228AB_healthy_donors_PBMNCs_annotated_Cal.h5ad",
    )
)

### Dataset Description

In [None]:
Hao_dataset_Train

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import rc_context
import scanpy as sc

# --- Config ---
label_key = 'Consensus_annotation_detailed_final'  # obs column with the cell-type labels
basis_key = 'X_wnn.umap'                           # embedding passed to sc.pl.embedding
color_key = f'{label_key}_colors'                  # .uns entry holding one colour per category

# Custom palette (label -> hex)
custom_palette = {
    'B Memory': "#68D827",
    'B Naive': '#1C511D',
    'CD14 Mono': "#D27CE3",
    'CD16 Mono': "#8D43CD",
    'CD4 T Memory': "#C1AF93",
    'CD4 T Naive': "#C99546",
    'CD8 T Memory': "#6B3317",
    'CD8 T Naive': "#4D382E",
    'ErP': "#D1235A",
    'Erythroblast': "#F30A1A",
    'GMP': "#C5E4FF",
    'HSC_MPP': '#0079ea',
    'Immature B': "#91FF7B",
    'LMPP': "#17BECF",
    'MAIT': "#BCBD22",
    'Myeloid progenitor': "#AEC7E8",
    'NK CD56 bright': "#F3AC1F",
    'NK CD56 dim': "#FBEF0D",
    'Plasma': "#9DC012",
    'Pro-B': "#66BB6A",
    'Small': "#292929",
    'cDC1': "#76A7CB",
    'cDC2': "#16D2E3",
    'gdT': "#EDB416",
    'pDC': "#69FFCB",
    'CD4 CTL': "#D7D2CB",
    'MEP': "#E364B0",
    'Pre-B': "#2DBD67",
    'Pre-Pro-B': '#92AC8E',
    'EoBaMaP': "#728245",
    'MkP': "#69424D",
    'Stroma': "#727272",
    'Macrophage': "#5F4761",
    'ILC': "#F7CF94",
    'dnT': "#504423",
    'Treg': "#6E6C37",
    'Platelet': "#FF39A6",
}

# --- Ensure categorical dtype ---
# Checked with isinstance(dtype, pd.CategoricalDtype); the older
# pd.api.types.is_categorical_dtype helper is deprecated in recent pandas.
if not isinstance(Hao_dataset_Train.obs[label_key].dtype, pd.CategoricalDtype):
    Hao_dataset_Train.obs[label_key] = Hao_dataset_Train.obs[label_key].astype('category')

cats = list(Hao_dataset_Train.obs[label_key].cat.categories)

# --- Sanity checks: missing/extra labels ---
labels_in_palette = set(custom_palette.keys())
labels_in_data = set(cats)

missing_in_palette = [c for c in cats if c not in labels_in_palette]
extra_in_palette   = [c for c in custom_palette.keys() if c not in labels_in_data]

if missing_in_palette:
    print("[WARN] Missing colors for:", missing_in_palette, "-> will use light grey (#cccccc).")
if extra_in_palette:
    print("[INFO] Palette has unused entries:", extra_in_palette)

# --- Build palette list in *category order* ---
fallback = '#cccccc'
palette_list = [custom_palette.get(c, fallback) for c in cats]

# --- Save onto .uns so Scanpy uses it consistently elsewhere ---
Hao_dataset_Train.uns[color_key] = palette_list

# --- Plot with Scanpy using built-in outlines (no clustering) ---
with rc_context({"figure.figsize": (5.2, 4.5)}):
    sc.pl.embedding(
        Hao_dataset_Train,
        basis=basis_key,
        color=label_key,
        palette=palette_list,
        legend_loc='on data',
        legend_fontsize=10,
        legend_fontoutline=1.5,
        size=10,
        add_outline=True,   # built-in group outlines
        frameon=True,
        title='Hao',
        show=False,         # keep the figure open so the axis labels can be cleared
    )
    ax = plt.gca()
    ax.set_xlabel('')
    ax.set_ylabel('')
    plt.tight_layout()
    plt.show()


In [None]:
# Replace .X of each Hao split (Train, Test, Cal) with its normalised version
Hao_dataset_Train.X, Hao_dataset_Test.X, Hao_dataset_Cal.X = (
    ep.Normalise_protein_data(adata.X)
    for adata in (Hao_dataset_Train, Hao_dataset_Test, Hao_dataset_Cal)
)

In [None]:
# Rank features per detailed consensus cell type (Wilcoxon), then plot the
# top 10 of every group on independent y-axes, three panels per row.
sc.tl.rank_genes_groups(
    Hao_dataset_Train,
    groupby='Consensus_annotation_detailed_final',
    method='wilcoxon',
)
sc.pl.rank_genes_groups(
    Hao_dataset_Train,
    n_genes=10,
    sharey=False,
    ncols=3,
    fontsize=14,
)

In [None]:
# Matrix plot of the ranked features computed by sc.tl.rank_genes_groups above.
sc.pl.rank_genes_groups_matrixplot(Hao_dataset_Train)

In [None]:
Hao_dataset_Train

In [None]:
# CD56 per 'celltype.l2' group (values from .X, not .raw; x tick labels rotated 90 degrees).
sc.pl.violin(Hao_dataset_Train, keys='CD56', groupby='celltype.l2', rotation=90, use_raw=False)

In [None]:
# Same CD56 violin plot, grouped by the detailed consensus annotation instead.
sc.pl.violin(Hao_dataset_Train, keys='CD56', groupby='Consensus_annotation_detailed_final', rotation=90, use_raw=False)

### ML pre-processing

In [None]:
# AnnData.to_df() returns .X as a DataFrame indexed by obs_names with var_names
# as columns, and converts .X to a dense array first when it is stored sparse
# (passing a sparse matrix straight to pd.DataFrame does not work).
Hao_data_Train = Hao_dataset_Train.to_df()
Hao_data_Test = Hao_dataset_Test.to_df()
Hao_data_Cal = Hao_dataset_Cal.to_df()

In [None]:
# Columns excluded from the feature matrix (presumably immunoglobulin / isotype
# channels; confirm). errors='ignore' makes absent columns a silent no-op.
columns_to_drop = ["IgD", "IgM", "Rag-IgG2c"]
Hao_data_Train, Hao_data_Test, Hao_data_Cal = (
    frame.drop(columns=columns_to_drop, errors='ignore')
    for frame in (Hao_data_Train, Hao_data_Test, Hao_data_Cal)
)

## Loading Zhang X. et al. (2024) dataset

In [None]:
# Read the Train / Test / Cal splits of the Zhang reference (in that order)
Zhang_dataset_Train, Zhang_dataset_Test, Zhang_dataset_Cal = (
    sc.read_h5ad(data_path + "/References/Zhang" + h5ad_name)
    for h5ad_name in (
        "/Zhang_adata_annotated_Train.h5ad",
        "/Zhang_adata_annotated_Test.h5ad",
        "/Zhang_adata_annotated_Cal.h5ad",
    )
)

### Dataset Description

In [None]:
Zhang_dataset_Train

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import rc_context
import scanpy as sc

# --- Config ---
label_key = 'Consensus_annotation_detailed_final'  # obs column with the cell-type labels
basis_key = 'X_umap'                               # embedding passed to sc.pl.embedding
color_key = f'{label_key}_colors'                  # .uns entry holding one colour per category

# Custom palette (label -> hex)
custom_palette = {
    'B Memory': "#68D827",
    'B Naive': '#1C511D',
    'CD14 Mono': "#D27CE3",
    'CD16 Mono': "#8D43CD",
    'CD4 T Memory': "#C1AF93",
    'CD4 T Naive': "#C99546",
    'CD8 T Memory': "#6B3317",
    'CD8 T Naive': "#4D382E",
    'ErP': "#D1235A",
    'Erythroblast': "#F30A1A",
    'GMP': "#C5E4FF",
    'HSC_MPP': '#0079ea',
    'Immature B': "#91FF7B",
    'LMPP': "#17BECF",
    'MAIT': "#BCBD22",
    'Myeloid progenitor': "#AEC7E8",
    'NK CD56 bright': "#F3AC1F",
    'NK CD56 dim': "#FBEF0D",
    'Plasma': "#9DC012",
    'Pro-B': "#66BB6A",
    'Small': "#292929",
    'cDC1': "#76A7CB",
    'cDC2': "#16D2E3",
    'gdT': "#EDB416",
    'pDC': "#69FFCB",
    'CD4 CTL': "#D7D2CB",
    'MEP': "#E364B0",
    'Pre-B': "#2DBD67",
    'Pre-Pro-B': '#92AC8E',
    'EoBaMaP': "#728245",
    'MkP': "#69424D",
    'Stroma': "#727272",
    'Macrophage': "#5F4761",
    'ILC': "#F7CF94",
    'dnT': "#504423",
    # The next two entries keep this palette identical to the one used for the
    # other reference datasets; labels absent from the data are simply unused.
    'Treg': "#6E6C37",
    'Platelet': "#FF39A6",
}

# --- Ensure categorical dtype ---
# Checked with isinstance(dtype, pd.CategoricalDtype); the older
# pd.api.types.is_categorical_dtype helper is deprecated in recent pandas.
if not isinstance(Zhang_dataset_Train.obs[label_key].dtype, pd.CategoricalDtype):
    Zhang_dataset_Train.obs[label_key] = Zhang_dataset_Train.obs[label_key].astype('category')

cats = list(Zhang_dataset_Train.obs[label_key].cat.categories)

# --- Sanity checks: missing/extra labels ---
labels_in_palette = set(custom_palette.keys())
labels_in_data = set(cats)

missing_in_palette = [c for c in cats if c not in labels_in_palette]
extra_in_palette   = [c for c in custom_palette.keys() if c not in labels_in_data]

if missing_in_palette:
    print("[WARN] Missing colors for:", missing_in_palette, "-> will use light grey (#cccccc).")
if extra_in_palette:
    print("[INFO] Palette has unused entries:", extra_in_palette)

# --- Build palette list in *category order* ---
fallback = '#cccccc'
palette_list = [custom_palette.get(c, fallback) for c in cats]

# --- Save onto .uns so Scanpy uses it consistently elsewhere ---
Zhang_dataset_Train.uns[color_key] = palette_list

# --- Plot with Scanpy using built-in outlines (no clustering) ---
with rc_context({"figure.figsize": (5.5, 4.5)}):
    sc.pl.embedding(
        Zhang_dataset_Train,
        basis=basis_key,
        color=label_key,
        palette=palette_list,
        legend_loc='on data',
        legend_fontsize=10,
        legend_fontoutline=1.5,
        size=10,
        add_outline=True,   # built-in group outlines
        frameon=True,
        title='Zhang',
        show=False,         # keep the figure open so the axis labels can be cleared
    )
    ax = plt.gca()
    ax.set_xlabel('')
    ax.set_ylabel('')
    plt.tight_layout()
    plt.show()


In [None]:
# Replace .X of each Zhang split (Train, Test, Cal) with its normalised version
Zhang_dataset_Train.X, Zhang_dataset_Test.X, Zhang_dataset_Cal.X = (
    ep.Normalise_protein_data(adata.X)
    for adata in (Zhang_dataset_Train, Zhang_dataset_Test, Zhang_dataset_Cal)
)

In [None]:
# Rank features per detailed consensus cell type (Wilcoxon), then plot the
# top 10 of every group on independent y-axes, three panels per row.
sc.tl.rank_genes_groups(
    Zhang_dataset_Train,
    groupby='Consensus_annotation_detailed_final',
    method='wilcoxon',
)
sc.pl.rank_genes_groups(
    Zhang_dataset_Train,
    n_genes=10,
    sharey=False,
    ncols=3,
    fontsize=14,
)

In [None]:
# CD123 per broad consensus group (values from .X, not .raw; x tick labels rotated 90 degrees).
sc.pl.violin(Zhang_dataset_Train, keys='CD123', groupby='Consensus_annotation_broad_final', rotation=90, use_raw=False)

### ML pre-processing

In [None]:
# AnnData.to_df() returns .X as a DataFrame indexed by obs_names with var_names
# as columns, and converts .X to a dense array first when it is stored sparse
# (passing a sparse matrix straight to pd.DataFrame does not work).
Zhang_data_Train = Zhang_dataset_Train.to_df()
Zhang_data_Test = Zhang_dataset_Test.to_df()
Zhang_data_Cal = Zhang_dataset_Cal.to_df()

In [None]:
# Columns excluded from the feature matrix (presumably IgG / isotype-control
# channels; confirm). errors='ignore' makes absent columns a silent no-op.
columns_to_drop = [
    "IgG.Fc",
    "Isotype_G0114F7",
    "Isotype_HTK888",
    "Isotype_MOPC.173",
    "Isotype_MOPC.21",
    "Isotype_MPC.11",
    "Isotype_RTK2071",
    "Isotype_RTK2758",
    "Isotype_RTK4174",
    "Isotype_RTK4530",
]
Zhang_data_Train, Zhang_data_Test, Zhang_data_Cal = (
    frame.drop(columns=columns_to_drop, errors='ignore')
    for frame in (Zhang_data_Train, Zhang_data_Test, Zhang_data_Cal)
)

## Loading Triana S. et al. (2021) dataset

In [None]:
# Read the Train / Test / Cal splits of the Triana reference (in that order)
Triana_dataset_Train, Triana_dataset_Test, Triana_dataset_Cal = (
    sc.read_h5ad(data_path + "/References/Triana" + h5ad_name)
    for h5ad_name in (
        "/97AB_young_and_old_adult_healthy_donor_BMMNCs_annotated_Train.h5ad",
        "/97AB_young_and_old_adult_healthy_donor_BMMNCs_annotated_Test.h5ad",
        "/97AB_young_and_old_adult_healthy_donor_BMMNCs_annotated_Cal.h5ad",
    )
)

### Dataset Description

In [None]:
Triana_dataset_Train

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import rc_context
import scanpy as sc

# --- Config ---
label_key = 'Consensus_annotation_detailed_final'  # obs column with the cell-type labels
basis_key = 'X_mofaumap'                           # embedding passed to sc.pl.embedding
color_key = f'{label_key}_colors'                  # .uns entry holding one colour per category

# Custom palette (label -> hex)
custom_palette = {
    'B Memory': "#68D827",
    'B Naive': '#1C511D',
    'CD14 Mono': "#D27CE3",
    'CD16 Mono': "#8D43CD",
    'CD4 T Memory': "#C1AF93",
    'CD4 T Naive': "#C99546",
    'CD8 T Memory': "#6B3317",
    'CD8 T Naive': "#4D382E",
    'ErP': "#D1235A",
    'Erythroblast': "#F30A1A",
    'GMP': "#C5E4FF",
    'HSC_MPP': '#0079ea',
    'Immature B': "#91FF7B",
    'LMPP': "#17BECF",
    'MAIT': "#BCBD22",
    'Myeloid progenitor': "#AEC7E8",
    'NK CD56 bright': "#F3AC1F",
    'NK CD56 dim': "#FBEF0D",
    'Plasma': "#9DC012",
    'Pro-B': "#66BB6A",
    'Small': "#292929",
    'cDC1': "#76A7CB",
    'cDC2': "#16D2E3",
    'gdT': "#EDB416",
    'pDC': "#69FFCB",
    'CD4 CTL': "#D7D2CB",
    'MEP': "#E364B0",
    'Pre-B': "#2DBD67",
    'Pre-Pro-B': '#92AC8E',
    'EoBaMaP': "#728245",
    'MkP': "#69424D",
    'Stroma': "#727272",
    'Macrophage': "#5F4761",
    'ILC': "#F7CF94",
    'dnT': "#504423",
    'Treg': "#6E6C37",
    'Platelet': "#FF39A6",
}

# --- Ensure categorical dtype ---
# Checked with isinstance(dtype, pd.CategoricalDtype); the older
# pd.api.types.is_categorical_dtype helper is deprecated in recent pandas.
if not isinstance(Triana_dataset_Train.obs[label_key].dtype, pd.CategoricalDtype):
    Triana_dataset_Train.obs[label_key] = Triana_dataset_Train.obs[label_key].astype('category')

cats = list(Triana_dataset_Train.obs[label_key].cat.categories)

# --- Sanity checks: missing/extra labels ---
labels_in_palette = set(custom_palette.keys())
labels_in_data = set(cats)

missing_in_palette = [c for c in cats if c not in labels_in_palette]
extra_in_palette   = [c for c in custom_palette.keys() if c not in labels_in_data]

if missing_in_palette:
    print("[WARN] Missing colors for:", missing_in_palette, "-> will use light grey (#cccccc).")
if extra_in_palette:
    print("[INFO] Palette has unused entries:", extra_in_palette)

# --- Build palette list in *category order* ---
fallback = '#cccccc'
palette_list = [custom_palette.get(c, fallback) for c in cats]

# --- Save onto .uns so Scanpy uses it consistently elsewhere ---
Triana_dataset_Train.uns[color_key] = palette_list

# --- Plot with Scanpy using built-in outlines (no clustering) ---
with rc_context({"figure.figsize": (5.25, 4.5)}):
    sc.pl.embedding(
        Triana_dataset_Train,
        basis=basis_key,
        color=label_key,
        palette=palette_list,
        legend_loc='on data',
        legend_fontsize=10,
        legend_fontoutline=1.5,
        size=10,
        add_outline=True,   # built-in group outlines
        frameon=True,
        title='Triana',
        show=False,         # keep the figure open so the axis labels can be cleared
    )
    ax = plt.gca()
    ax.set_xlabel('')
    ax.set_ylabel('')
    plt.tight_layout()
    plt.show()


In [None]:
# Replace .X of each Triana split (Train, Test, Cal) with its normalised version
Triana_dataset_Train.X, Triana_dataset_Test.X, Triana_dataset_Cal.X = (
    ep.Normalise_protein_data(adata.X)
    for adata in (Triana_dataset_Train, Triana_dataset_Test, Triana_dataset_Cal)
)

In [None]:
# Rank features per simplified consensus cell type (Wilcoxon), then plot the
# top 10 of every group on independent y-axes, three panels per row.
# NOTE(review): the other reference datasets group by the *detailed* annotation
# here; confirm that 'simplified' is intended for Triana.
sc.tl.rank_genes_groups(
    Triana_dataset_Train,
    groupby='Consensus_annotation_simplified_final',
    method='wilcoxon',
)
sc.pl.rank_genes_groups(
    Triana_dataset_Train,
    n_genes=10,
    sharey=False,
    ncols=3,
    fontsize=14,
)

In [None]:
# CD133 per detailed consensus group (values from .X, not .raw; x tick labels rotated 90 degrees).
sc.pl.violin(Triana_dataset_Train, keys='CD133', groupby='Consensus_annotation_detailed_final', rotation=90, use_raw=False)

### ML pre-processing

In [None]:
# AnnData.to_df() returns .X as a DataFrame indexed by obs_names with var_names
# as columns, and converts .X to a dense array first when it is stored sparse
# (passing a sparse matrix straight to pd.DataFrame does not work).
Triana_data_Train = Triana_dataset_Train.to_df()
Triana_data_Test = Triana_dataset_Test.to_df()
Triana_data_Cal = Triana_dataset_Cal.to_df()

In [None]:
# Columns excluded from the feature matrix (presumably immunoglobulin channels;
# confirm). errors='ignore' makes absent columns a silent no-op.
columns_to_drop = ["IgG", "IgD"]
Triana_data_Train, Triana_data_Test, Triana_data_Cal = (
    frame.drop(columns=columns_to_drop, errors='ignore')
    for frame in (Triana_data_Train, Triana_data_Test, Triana_data_Cal)
)

## Loading Luecken M.D. et al. (2021) dataset

In [None]:
# Read the Train / Test / Cal splits of the Luecken reference (in that order)
Luecken_dataset_Train, Luecken_dataset_Test, Luecken_dataset_Cal = (
    sc.read_h5ad(data_path + "/References/Luecken" + h5ad_name)
    for h5ad_name in (
        "/140AB_adult_healthy_donor_BMMNCs_annotated_Train.h5ad",
        "/140AB_adult_healthy_donor_BMMNCs_annotated_Test.h5ad",
        "/140AB_adult_healthy_donor_BMMNCs_annotated_Cal.h5ad",
    )
)

### Dataset Description

In [None]:
Luecken_dataset_Train

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import rc_context
import scanpy as sc

# --- Config ---
label_key = 'Consensus_annotation_detailed_final'  # obs column with the cell-type labels
basis_key = 'X_umap'                               # embedding passed to sc.pl.embedding
color_key = f'{label_key}_colors'                  # .uns entry holding one colour per category

# Custom palette (label -> hex)
custom_palette = {
    'B Memory': "#68D827",
    'B Naive': '#1C511D',
    'CD14 Mono': "#D27CE3",
    'CD16 Mono': "#8D43CD",
    'CD4 T Memory': "#C1AF93",
    'CD4 T Naive': "#C99546",
    'CD8 T Memory': "#6B3317",
    'CD8 T Naive': "#4D382E",
    'ErP': "#D1235A",
    'Erythroblast': "#F30A1A",
    'GMP': "#C5E4FF",
    'HSC_MPP': '#0079ea',
    'Immature B': "#91FF7B",
    'LMPP': "#17BECF",
    'MAIT': "#BCBD22",
    'Myeloid progenitor': "#AEC7E8",
    'NK CD56 bright': "#F3AC1F",
    'NK CD56 dim': "#FBEF0D",
    'Plasma': "#9DC012",
    'Pro-B': "#66BB6A",
    'Small': "#292929",
    'cDC1': "#76A7CB",
    'cDC2': "#16D2E3",
    'gdT': "#EDB416",
    'pDC': "#69FFCB",
    'CD4 CTL': "#D7D2CB",
    'MEP': "#E364B0",
    'Pre-B': "#2DBD67",
    'Pre-Pro-B': '#92AC8E',
    'EoBaMaP': "#728245",
    'MkP': "#69424D",
    'Stroma': "#727272",
    'Macrophage': "#5F4761",
    'ILC': "#F7CF94",
    'dnT': "#504423",
    'Treg': "#6E6C37",
    'Platelet': "#FF39A6",
}

# --- Ensure categorical dtype ---
# Checked with isinstance(dtype, pd.CategoricalDtype); the older
# pd.api.types.is_categorical_dtype helper is deprecated in recent pandas.
if not isinstance(Luecken_dataset_Train.obs[label_key].dtype, pd.CategoricalDtype):
    Luecken_dataset_Train.obs[label_key] = Luecken_dataset_Train.obs[label_key].astype('category')

cats = list(Luecken_dataset_Train.obs[label_key].cat.categories)

# --- Sanity checks: missing/extra labels ---
labels_in_palette = set(custom_palette.keys())
labels_in_data = set(cats)

missing_in_palette = [c for c in cats if c not in labels_in_palette]
extra_in_palette   = [c for c in custom_palette.keys() if c not in labels_in_data]

if missing_in_palette:
    print("[WARN] Missing colors for:", missing_in_palette, "-> will use light grey (#cccccc).")
if extra_in_palette:
    print("[INFO] Palette has unused entries:", extra_in_palette)

# --- Build palette list in *category order* ---
fallback = '#cccccc'
palette_list = [custom_palette.get(c, fallback) for c in cats]

# --- Save onto .uns so Scanpy uses it consistently elsewhere ---
Luecken_dataset_Train.uns[color_key] = palette_list

# --- Plot with Scanpy using built-in outlines (no clustering) ---
with rc_context({"figure.figsize": (5, 4.5)}):
    sc.pl.embedding(
        Luecken_dataset_Train,
        basis=basis_key,
        color=label_key,
        palette=palette_list,
        legend_loc='on data',
        legend_fontsize=10,
        legend_fontoutline=1.5,
        size=10,
        add_outline=True,   # built-in group outlines
        frameon=True,
        title='Luecken',
        show=False,         # keep the figure open so the axis labels can be cleared
    )
    ax = plt.gca()
    ax.set_xlabel('')
    ax.set_ylabel('')
    plt.tight_layout()
    plt.show()


In [None]:
# Replace .X of each Luecken split (Train, Test, Cal) with its normalised version
Luecken_dataset_Train.X, Luecken_dataset_Test.X, Luecken_dataset_Cal.X = (
    ep.Normalise_protein_data(adata.X)
    for adata in (Luecken_dataset_Train, Luecken_dataset_Test, Luecken_dataset_Cal)
)

In [None]:
# Rank features per detailed consensus cell type (Wilcoxon), then plot the
# top 10 of every group on independent y-axes, three panels per row.
sc.tl.rank_genes_groups(
    Luecken_dataset_Train,
    groupby='Consensus_annotation_detailed_final',
    method='wilcoxon',
)
sc.pl.rank_genes_groups(
    Luecken_dataset_Train,
    n_genes=10,
    sharey=False,
    ncols=3,
    fontsize=14,
)

In [None]:
# CD49b per detailed consensus group (values from .X, not .raw; x tick labels rotated 90 degrees).
sc.pl.violin(Luecken_dataset_Train, keys='CD49b', groupby='Consensus_annotation_detailed_final', rotation=90, use_raw=False)

### ML pre-processing

In [None]:
# AnnData.to_df() returns .X as a DataFrame indexed by obs_names with var_names
# as columns, and converts .X to a dense array first when it is stored sparse
# (passing a sparse matrix straight to pd.DataFrame does not work).
Luecken_data_Train = Luecken_dataset_Train.to_df()
Luecken_data_Test = Luecken_dataset_Test.to_df()
Luecken_data_Cal = Luecken_dataset_Cal.to_df()

In [None]:
# Columns excluded from the feature matrix (presumably immunoglobulin channels;
# confirm). errors='ignore' makes absent columns a silent no-op.
columns_to_drop = ["IgG", "IgM", "IgD"]
Luecken_data_Train, Luecken_data_Test, Luecken_data_Cal = (
    frame.drop(columns=columns_to_drop, errors='ignore')
    for frame in (Luecken_data_Train, Luecken_data_Test, Luecken_data_Cal)
)

# Loading antibodies panels

In [None]:
# Feature names of the antibody panel: the CSV's first column is read as the
# index and only that index (a pandas Index) is kept.
TotalSeqD_Heme_Oncology_CAT399906 = pd.read_csv(data_path + "/Antibodies_panels/TotalSeqD_Heme_Oncology_CAT399906.csv", index_col=0).index

# TotalSeqD Heme Oncology CAT399906 Models Training

In [None]:
# Point data_path at the output folder for this antibody panel.
# The suffix is appended only when it is not already there, so re-running this
# cell does not nest the folder inside itself.
# NOTE: from here on data_path no longer refers to the raw "Data" folder.
_panel_models_subdir = '/Pre_trained_models/TotalSeqD_Heme_Oncology_CAT399906'
if not data_path.endswith(_panel_models_subdir):
    data_path = data_path + _panel_models_subdir
os.makedirs(data_path, exist_ok=True)


## Hao Models

In [None]:
# Output root for the Hao models and its "Models" sub-folder
models_output = data_path + "/Hao"
os.makedirs(models_output, exist_ok=True)
os.makedirs(models_output + "/Models", exist_ok=True)

### ML Training

In [None]:
# Empty container for the Hao models; it is filled outside this section
# (presumably keyed by annotation level or model name; verify against the training cells).
Hao_Models = {}

#### Broad annotation

In [None]:
# -*- coding: utf-8 -*-
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.ensemble import StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.calibration import CalibratedClassifierCV
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix

from netcal.scaling import TemperatureScaling
import joblib

# ------------------------------------------------------------------- CONFIG (expects these to already exist)
#   models_output, train_barcodes_path, test_barcodes_path
#   Hao_data_Train, Hao_data_Test, Hao_data_Cal          (DataFrames indexed by barcode)
#   Hao_dataset_Train, Hao_dataset_Test, Hao_dataset_Cal (AnnData with obs labels)
#   TotalSeqD_Heme_Oncology_CAT399906                    (iterable of feature names)
#   MLTraining module with: CV, train_NB, train_XGB, train_KNN, train_MLP,
#                           plot_calibration_curve, save_models, evaluate_classifier, append_metrics_csv

name_target_class = "Broad"   # "Broad" | "Simplified" | "Detailed"

# Figure and model folders live side by side under models_output
fig_root, models_dir = (Path(models_output) / sub for sub in ("Figures", "Models"))
fig_root.mkdir(parents=True, exist_ok=True)
models_dir.mkdir(parents=True, exist_ok=True)

metrics_log = []
num_cores   = -1             # overrides the core count computed in the set-up cell
kf          = MLTraining.CV  # replaces the KFold splitter created in the set-up cell

# ============================= HELPERS =============================

def _norm_feats(names) -> pd.Index:
    """
    Normalizer used ONLY to construct matching keys.
    Panel names remain untouched; data columns are normalized and then mapped BACK
    to the exact panel names via a lookup.
    """
    s = pd.Index(map(str, names))
    s = (s.str.strip()
           .str.lower()
           .str.replace(r"[ _/]+", "-", regex=True)
           .str.replace(r"-+", "-", regex=True)
           .str.strip("-"))
    return s

def attach_celltype(df: pd.DataFrame, ad: "AnnData", field: str) -> pd.DataFrame:
    """Return a copy of `df` with a categorical 'Celltype' column taken from `ad.obs[field]`.

    Labels are converted to strings, stripped, and every whitespace run is
    replaced by "_". They are aligned to `df` by index; rows of `df` whose index
    is absent from `ad.obs` get a missing Celltype and a warning is printed.

    Raises
    ------
    KeyError
        If `field` is not a column of `ad.obs`.
    """
    if field not in ad.obs:
        raise KeyError(f"'{field}' not found in AnnData.obs")

    labels = ad.obs[field].astype("string").str.strip()
    labels = labels.str.replace(r"\s+", "_", regex=True)

    out = df.copy()
    out["Celltype"] = pd.Categorical(labels.reindex(out.index))

    missing = int(out["Celltype"].isna().sum())
    if missing:
        print(f"[WARN] {missing} rows got NaN Celltype after reindex; check barcode alignment.")
    return out

def _check_finite(df: pd.DataFrame, tag: str):
    arr = df.to_numpy()
    if not np.isfinite(arr).all():
        bad = np.where(~np.isfinite(arr))
        raise ValueError(f"Non-finite values found in {tag} features at positions {bad}")

def _unwrap_estimator(m):
    return getattr(m, "estimator", None) or getattr(m, "base_estimator", None) or m

def _assert_feature_counts(cell_name: str, models_dict: dict, expected: int):
    pairs = [
        ("NB",  models_dict.get("NB")),
        ("XGB", models_dict.get("XGB")),
        ("KNN", models_dict.get("KNN")),
        ("MLP", models_dict.get("MLP")),
        ("Stacker", models_dict.get("Stacker")),
    ]
    for name, est in pairs:
        if est is None:
            continue
        base = _unwrap_estimator(est)
        nfi = getattr(base, "n_features_in_", None)
        if nfi is not None and nfi != expected:
            raise RuntimeError(f"{cell_name}:{name} saw {nfi} features; expected {expected}")

# ============================= LABEL ATTACH =============================

# obs column that matches the selected annotation level
consensus_field = f"Consensus_annotation_{name_target_class.lower()}_final"

# Add the 'Celltype' column to each split from its own AnnData object
Hao_data_Train, Hao_data_Test, Hao_data_Cal = (
    attach_celltype(frame, adata, consensus_field)
    for frame, adata in (
        (Hao_data_Train, Hao_dataset_Train),
        (Hao_data_Test,  Hao_dataset_Test),
        (Hao_data_Cal,   Hao_dataset_Cal),
    )
)

# ============================= PANEL & DATA COLUMN ALIGNMENT =============================

# Panel names are kept exactly as provided (only converted to str)
panel = pd.Index([str(name) for name in TotalSeqD_Heme_Oncology_CAT399906])

# Lookup: normalized key -> exact panel name
panel_keys    = _norm_feats(panel)
norm_to_panel = {key: exact for key, exact in zip(panel_keys, panel)}

# Fewer keys than panel names means two names shared one normalized key
if len(panel) != len(norm_to_panel):
    raise ValueError("Panel contains names that collide after normalization. Consider adjusting _norm_feats rules.")

def _rename_data_to_panel(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of `df` whose feature columns carry the exact panel names.

    A feature column is renamed when its normalized key (see `_norm_feats`) is
    present in `norm_to_panel`; 'cell_barcode' and 'Celltype' are never touched.
    When several data columns map to the same panel name, the first one is kept
    and the others are dropped (with a warning). Columns with no panel match
    are left unchanged.
    """
    out = df.copy()
    feat = pd.Index([c for c in out.columns if c not in ("cell_barcode", "Celltype")])

    # data column -> exact panel name, for columns whose key is in the panel
    rename_map = {}
    for old, key in zip(feat, _norm_feats(feat)):
        new = norm_to_panel.get(key)
        if new is not None:
            rename_map[old] = new

    # First column claiming a panel name wins; later claimants are dropped
    claimed, safe_map, drops = set(), {}, []
    for old, new in rename_map.items():
        if new in claimed:
            drops.append(old)
            continue
        claimed.add(new)
        safe_map[old] = new

    if drops:
        print(f"[WARN] Dropping {len(drops)} duplicated-mapped columns (showing up to 5): {drops[:5]}")
        out = out.drop(columns=drops, errors="ignore")
    out = out.rename(columns=safe_map)

    matched = len(safe_map)
    print(f"[map] matched {matched}/{len(feat)} data columns to panel")
    return out

# Rename the feature columns of every split; the panel itself is not modified
Hao_data_Train, Hao_data_Test, Hao_data_Cal = (
    _rename_data_to_panel(frame)
    for frame in (Hao_data_Train, Hao_data_Test, Hao_data_Cal)
)

# Restrict each split to panel features, keeping the panel's own ordering
def _panel_intersection(df: pd.DataFrame) -> pd.DataFrame:
    """
    Keep only the feature columns of `df` that are panel names, in `panel` order.

    'cell_barcode' and 'Celltype', when present, are appended after the features.

    Raises:
        ValueError: if no feature column of `df` is a panel name.
    """
    reserved = [name for name in ("cell_barcode", "Celltype") if name in df.columns]
    candidates = pd.Index([name for name in df.columns if name not in reserved])
    shared = panel.intersection(candidates, sort=False)
    if len(shared) == 0:
        raise ValueError("Panel/Data intersection is empty after renaming. Check mapping rules.")
    return df.reindex(columns=[*shared, *reserved])

# Apply the panel restriction to Train, Test and Cal (in that order).
Hao_data_Train, Hao_data_Test, Hao_data_Cal = [
    _panel_intersection(split_df)
    for split_df in (Hao_data_Train, Hao_data_Test, Hao_data_Cal)
]

# ============================= FEATURES & LABELS =============================

# Calibration labels, kept as a one-column frame for the label mapping further down.
Hao_data_Cal_lbl = Hao_data_Cal[["Celltype"]].copy()

# Non-feature columns that are actually present in each split.
drop_cols_train = [col for col in ("cell_barcode", "Celltype") if col in Hao_data_Train.columns]
drop_cols_test  = [col for col in ("cell_barcode", "Celltype") if col in Hao_data_Test.columns]
drop_cols_cal   = [col for col in ("cell_barcode", "Celltype") if col in Hao_data_Cal.columns]

# Feature-only frames for the three splits.
Hao_data_Train_Sub = Hao_data_Train.drop(columns=drop_cols_train, errors="ignore")
Hao_data_Test_Sub  = Hao_data_Test.drop(columns=drop_cols_test,  errors="ignore")
Hao_data_Cal_Sub   = Hao_data_Cal.drop(columns=drop_cols_cal,    errors="ignore")

# SAFETY: all three splits must expose the same features in the same order.
cols_train = list(Hao_data_Train_Sub.columns)
if not (list(Hao_data_Test_Sub.columns) == cols_train == list(Hao_data_Cal_Sub.columns)):
    raise ValueError("Train/Cal/Test feature columns differ after panel intersection!")

# SAFETY: no NaN/inf may reach the scalers and learners.
_check_finite(Hao_data_Train_Sub, "TRAIN")
_check_finite(Hao_data_Test_Sub,  "TEST")
_check_finite(Hao_data_Cal_Sub,   "CAL")

print(f"\n[features] Using {len(cols_train)} panel-intersected features (exact panel names):")
print(cols_train)

# Class order = sorted unique non-null TRAIN labels; it fixes the column order of P_cal / P_te.
class_names  = sorted(Hao_data_Train["Celltype"].dropna().unique())
K            = len(class_names)
class_to_idx = dict(zip(class_names, range(K)))

# Integer multiclass labels for CAL; a label that TRAIN never saw is an error.
s_cal = Hao_data_Cal_lbl["Celltype"].map(class_to_idx)
if s_cal.isna().any():
    missing = Hao_data_Cal_lbl.loc[s_cal.isna(), "Celltype"].unique()
    raise ValueError(f"Unknown labels in CAL split: {missing}")
y_cal_multiclass = s_cal.to_numpy(dtype=np.int64)

# Integer multiclass labels for TEST, same rule.
s_te = Hao_data_Test["Celltype"].map(class_to_idx)
if s_te.isna().any():
    missing = Hao_data_Test.loc[s_te.isna(), "Celltype"].unique()
    raise ValueError(f"Unknown labels in TEST split: {missing}")
y_test_multiclass = s_te.to_numpy(dtype=np.int64)

# Full CAL / TEST feature frames, shared by every binary head.
X_cal_all_df = Hao_data_Cal_Sub.copy()
X_te_all_df  = Hao_data_Test_Sub.copy()

# One-vs-rest probability matrices, one column per class (filled inside the per-class loop).
P_cal = np.zeros((len(X_cal_all_df), K), dtype=float)
P_te  = np.zeros((len(X_te_all_df),  K), dtype=float)

test_index = Hao_data_Test_Sub.index

# ============================= TRAIN PER-CLASS OVR =============================

# One binary (one-vs-rest) head per class: slice TRAIN/TEST by barcode lists,
# scale, train base learners + stacker, calibrate on CAL, save, and record the
# head's probabilities in column k of P_cal / P_te.
for celltype in class_names:
    k = class_to_idx[celltype]
    name = str(celltype).replace(" ", "_")
    print(f"\nProcessing {name} (class {k+1}/{K})...")

    # ---- TRAIN slice via barcode lists
    # The CSV must provide "Positive" and "Negative" barcode columns; NaN entries are dropped.
    train_barcodes_df = pd.read_csv(
        f"{train_barcodes_path}/Hao/Consensus_annotation_broad_final/Barcodes_training_class_{name}.csv",
        index_col=0
    )
    train_positive_barcodes = train_barcodes_df["Positive"].dropna().values
    train_negative_barcodes = train_barcodes_df["Negative"].dropna().values
    all_train_barcodes = np.concatenate([train_positive_barcodes, train_negative_barcodes])

    train_mask = Hao_data_Train_Sub.index.isin(all_train_barcodes)
    X_tr_df = Hao_data_Train_Sub.loc[train_mask]
    found_train_barcodes = X_tr_df.index.values
    y_tr = np.isin(found_train_barcodes, train_positive_barcodes).astype(int)

    # ---- Skip guards
    # NOTE: a skipped class keeps the preallocated all-zero column k in P_cal / P_te.
    if X_tr_df.empty or np.unique(y_tr).size < 2:
        print(f"[SKIP] {name}: empty or single-class train slice (pos={y_tr.sum()}, neg={(len(y_tr)-y_tr.sum())}).")
        continue

    # ---- TEST slice via barcode lists
    test_barcodes_df = pd.read_csv(
        f"{test_barcodes_path}/Hao/Consensus_annotation_broad_final/Barcodes_testing_class_{name}.csv",
        index_col=0
    )
    test_positive_barcodes = test_barcodes_df["Positive"].dropna().values
    test_negative_barcodes = test_barcodes_df["Negative"].dropna().values
    all_test_barcodes = np.concatenate([test_positive_barcodes, test_negative_barcodes])

    test_mask = Hao_data_Test_Sub.index.isin(all_test_barcodes)
    X_te_df = Hao_data_Test_Sub.loc[test_mask]
    found_test_barcodes = X_te_df.index.values
    y_te = np.isin(found_test_barcodes, test_positive_barcodes).astype(int)

    # ---- Full-test & cal for this binary head
    # These frames are only read below (scaling builds new frames), so the shared
    # frames are referenced directly instead of being deep-copied for every class.
    X_te_all_local = X_te_all_df
    y_te_all = (Hao_data_Test["Celltype"].values == celltype).astype(int)
    X_cal_df = X_cal_all_df
    y_cal_bin = (Hao_data_Cal_lbl["Celltype"].values == celltype).astype(int)

    # ---- Info
    print(f"Training - Found {X_tr_df.shape[0]} / {len(all_train_barcodes)} barcodes")
    print(f"Training - Pos: {len(train_positive_barcodes)}, Neg: {len(train_negative_barcodes)}")
    print(f"Training labels: {y_tr.sum()} pos, {len(y_tr)-y_tr.sum()} neg")
    print(f"Testing  - Found {X_te_df.shape[0]} / {len(all_test_barcodes)} barcodes")
    print(f"Testing  - Pos: {len(test_positive_barcodes)}, Neg: {len(test_negative_barcodes)}")
    print(f"Testing labels: {y_te.sum()} pos, {len(y_te)-y_te.sum()} neg")
    print(f"Calibrating - Found {X_cal_df.shape[0]} rows | Pos: {y_cal_bin.sum()}, Neg: {len(y_cal_bin)-y_cal_bin.sum()}")
    print(f"All test data: {X_te_all_local.shape[0]} rows, positives for {celltype}: {y_te_all.sum()}")

    # ---- Scaling (fit on per-head TRAIN slice; transform others)
    scaler = StandardScaler(with_mean=True, with_std=True).fit(X_tr_df.values)

    def _sc(df):
        """Scale `df` with this head's scaler, keeping its index and the shared feature names."""
        return pd.DataFrame(
            scaler.transform(df.values),
            index=df.index,
            columns=cols_train,
        )

    X_tr_sc_df     = _sc(X_tr_df)
    # scikit-learn's input validation rejects 0-row arrays, so an empty per-class
    # test slice must not be sent through the scaler; binary metrics are skipped
    # for it further down instead of aborting the whole training run here.
    has_test_slice = not X_te_df.empty
    X_te_sc_df     = _sc(X_te_df) if has_test_slice else X_te_df.copy()
    X_te_all_sc_df = _sc(X_te_all_local)
    X_cal_sc_df    = _sc(X_cal_df)

    print(f"[scale] {name}: train mean ~ {X_tr_sc_df.values.mean():.3f}, std ~ {X_tr_sc_df.values.std():.3f}")

    # ---- Base learners
    NB_model  = MLTraining.train_NB (X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    XGB_model = MLTraining.train_XGB(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    KNN_model = MLTraining.train_KNN(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    MLP_model = MLTraining.train_MLP(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)

    # ---- Stacker (raw)
    stacker_raw = StackingClassifier(
        estimators=[("NB", NB_model), ("XGB", XGB_model), ("KNN", KNN_model), ("MLP", MLP_model)],
        final_estimator=LogisticRegression(max_iter=2000, class_weight="balanced", random_state=42),
        stack_method="predict_proba",
        cv=kf,
        n_jobs=-1,
    ).fit(X_tr_sc_df, y_tr)

    # ---- Feature count asserts (debug safety)
    expected_feats = len(cols_train)
    _assert_feature_counts(name, {
        "NB": NB_model, "XGB": XGB_model, "KNN": KNN_model, "MLP": MLP_model, "Stacker": stacker_raw
    }, expected_feats)

    # ---- Binary calibration (Platt, guarded)
    pos_cal    = int(y_cal_bin.sum())
    n_cal_bin  = int(len(y_cal_bin))
    has_both   = (0 < pos_cal < n_cal_bin)   # sigmoid calibration needs both classes in CAL
    print(f"[CAL] {name}: cal positives={pos_cal}/{n_cal_bin}")

    if has_both:
        try:
            calibrator = CalibratedClassifierCV(estimator=stacker_raw, method="sigmoid", cv="prefit")
        except TypeError:  # older sklearn
            calibrator = CalibratedClassifierCV(base_estimator=stacker_raw, method="sigmoid", cv="prefit")
        stacker = calibrator.fit(X_cal_sc_df, y_cal_bin)
    else:
        print(f"[WARN] Skipping calibration for {name}: single-class cal set.")
        stacker = stacker_raw

    # ---- Calibration plot on all-test (optional)
    # y_proba_cal stays None if prediction fails inside the best-effort plot block;
    # when it succeeds it is reused for P_te so the full test set is scored only once.
    y_proba_cal = None
    try:
        y_proba_uncal = stacker_raw.predict_proba(X_te_all_sc_df)[:, 1]
        # Without calibration `stacker` is `stacker_raw`, so the raw scores are reused.
        y_proba_cal   = stacker.predict_proba(X_te_all_sc_df)[:, 1] if has_both else y_proba_uncal
        if has_both:
            _ = MLTraining.plot_calibration_curve(
                y_te_all, [y_proba_uncal, y_proba_cal],
                clf_names=["Uncalibrated", "Calibrated"],
                n_bins=15, strategy="quantile",
                title=f"Calibration – {name_target_class}:{name}"
            )
        else:
            _ = MLTraining.plot_calibration_curve(
                y_te_all, [y_proba_uncal],
                clf_names=["Uncalibrated"],
                n_bins=15, strategy="quantile",
                title=f"Calibration (uncal only) – {name_target_class}:{name}"
            )
        plt.tight_layout()
        plt.show()
    except Exception as e:
        print(f"[WARN] Skipped calibration plot for {name}: {e}")

    # ---- Save per-class bundle (model + scaler + columns)
    save_subdir = models_dir / f"{name_target_class}_{name}"
    save_subdir.mkdir(parents=True, exist_ok=True)

    MLTraining.save_models({"Stacked": stacker}, out_dir=save_subdir, tag=f"{name_target_class}_{name}")
    joblib.dump(cols_train, save_subdir / "feature_names.joblib")

    bundle = {
        "atlas": "Hao",
        "depth": name_target_class,
        "label": celltype,
        "model": stacker,          # CalibratedClassifierCV(StackingClassifier) or StackingClassifier
        "columns": cols_train,     # exact panel names, panel order
        "scaler": scaler,          # per-head scaler
        "panel_name": "TotalSeqD_Heme_Oncology_CAT399906",
    }
    bundle_path = save_subdir / f"{name_target_class}_{name}_bundle.joblib"
    joblib.dump(bundle, bundle_path)
    print(f"[SAVE] Wrote bundle with columns+scaler to {bundle_path}")

    # ---- Binary metrics on the class-specific test slice (best-effort)
    if has_test_slice:
        try:
            m = MLTraining.evaluate_classifier(stacker, X_te_sc_df, y_te, plot_cm=False)
            m.update(celltype=celltype)
            metrics_log.append(m)
            print(f"\n{celltype}\n", m.get("report", ""))
        except Exception as e:
            print(f"[WARN] Binary metrics for {name} skipped: {e}")
    else:
        print(f"[WARN] Binary metrics for {name} skipped: no test barcodes found in TEST split.")

    # ---- Store OvR probs for multiclass calibration
    P_cal[:, k] = stacker.predict_proba(X_cal_sc_df)[:, 1]
    if y_proba_cal is None:
        # The plot block failed before scoring; score the full test set here.
        y_proba_cal = stacker.predict_proba(X_te_all_sc_df)[:, 1]
    P_te[:,  k] = y_proba_cal

# ============================= MULTICLASS CALIBRATION =============================

print("\nFitting multiclass TemperatureScaling on CAL split...")

# Guards: every entry must be a probability in [0, 1]. The check is phrased
# positively so NaN entries (which fail every comparison) are rejected as well.
if not ((P_cal >= 0) & (P_cal <= 1)).all():
    raise ValueError("P_cal must be probabilities in [0,1].")
if not ((P_te >= 0) & (P_te <= 1)).all():
    raise ValueError("P_te must be probabilities in [0,1].")

ts_cal = TemperatureScaling()
ts_cal.fit(P_cal, y_cal_multiclass)
P_te_mc = ts_cal.transform(P_te)

# Ensure calibrated probs shape (n_test, K)
P_te_mc = np.asarray(P_te_mc)
ts_returned_shape = P_te_mc.shape   # kept so the fallback warning reports what was actually returned
if P_te_mc.ndim == 1:
    P_te_mc = P_te_mc.reshape(-1, 1)
if P_te_mc.shape[1] == 1 and K == 2:
    # Single positive-class column in the binary case: rebuild the two-column form.
    P_te_mc = np.hstack([1.0 - P_te_mc, P_te_mc])
elif P_te_mc.shape[1] != K:
    # Fallback: normalize OvR sums (rows summing to 0 are left as all zeros)
    row_sums = P_te.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0.0] = 1.0
    P_te_mc = P_te / row_sums
    print(f"[WARN] TemperatureScaling returned shape {ts_returned_shape}; fell back to sum-normalized OvR probs.")

# Persist multiclass temp scaler + class names
joblib.dump(ts_cal, models_dir / f"{name_target_class}_multiclass_temp_scaler.joblib")
(pd.Series(class_names, name="class_name")
   .to_csv(models_dir / f"{name_target_class}_class_names.csv", index=False))

# ============================= PROBS COMPARISON & METRICS =============================

# Side-by-side table of raw OvR probabilities and multiclass-calibrated probabilities.
probs_raw_df = pd.DataFrame(P_te,    index=test_index, columns=[f"raw_{c}" for c in class_names])
probs_mc_df  = pd.DataFrame(P_te_mc, index=test_index, columns=[f"mc_{c}"  for c in class_names])

# Predicted class index per row, computed once and reused for the table and the report.
y_pred_raw = P_te.argmax(axis=1)
y_pred_mc  = P_te_mc.argmax(axis=1)

probs_compare = pd.concat([probs_raw_df, probs_mc_df], axis=1)
probs_compare["true_label"]    = Hao_data_Test["Celltype"].values
probs_compare["pred_raw"]      = y_pred_raw
probs_compare["pred_mc"]       = y_pred_mc
probs_compare["pred_raw_name"] = [class_names[i] for i in y_pred_raw]
probs_compare["pred_mc_name"]  = [class_names[i] for i in y_pred_mc]

print("\nPreview of probabilities BEFORE (raw OvR) vs AFTER (multiclass TS):")
print(probs_compare.head(10).to_string())

probs_compare_path = models_dir / f"{name_target_class}_probabilities_before_after_TEST.csv"
probs_compare.to_csv(probs_compare_path, index=True)
print(f"\nSaved probabilities comparison to: {probs_compare_path}")

# Multiclass evaluation
# `labels` is passed explicitly (as for the confusion matrix) so the report keeps
# one row per class and does not raise when a class is absent from both the true
# and the predicted labels of the TEST split.
all_label_ids = list(range(K))
print("\nMulticlass classification report (TEST):")
print(classification_report(y_test_multiclass, y_pred_mc, labels=all_label_ids,
                            target_names=class_names, digits=3))

cm = confusion_matrix(y_test_multiclass, y_pred_mc, labels=all_label_ids)
print("Confusion matrix (rows=true, cols=pred):\n", cm)

# Per-class binary head metrics CSV
metrics_df = pd.DataFrame.from_records(metrics_log)
MLTraining.append_metrics_csv(metrics_df, csv_path=Path(models_output) / "stacker_metrics.csv")

print("\nDone.")


#### Simplified annotation

In [None]:
# -*- coding: utf-8 -*-
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.ensemble import StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.calibration import CalibratedClassifierCV
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix

from netcal.scaling import TemperatureScaling
import joblib

# ------------------------------------------------------------------- CONFIG (expects these to already exist)
#   models_output, train_barcodes_path, test_barcodes_path
#   Hao_data_Train, Hao_data_Test, Hao_data_Cal          (DataFrames indexed by barcode)
#   Hao_dataset_Train, Hao_dataset_Test, Hao_dataset_Cal (AnnData with obs labels)
#   TotalSeqD_Heme_Oncology_CAT399906                    (iterable of feature names)
#   MLTraining module with: CV, train_NB, train_XGB, train_KNN, train_MLP,
#                           plot_calibration_curve, save_models, evaluate_classifier, append_metrics_csv

# Annotation depth trained by this cell; its lower-cased form selects the consensus label field.
# NOTE(review): other depths are presumably "Broad" and "Detailed" — confirm against the sibling cells.
name_target_class = "Simplified"

# Output locations (created on demand).
models_dir = Path(models_output) / "Models"
fig_root   = Path(models_output) / "Figures"
models_dir.mkdir(parents=True, exist_ok=True)
fig_root.mkdir(parents=True, exist_ok=True)

# Training settings shared by all binary heads of this cell.
metrics_log = []               # one metrics record per trained head
num_cores   = -1
kf          = MLTraining.CV

# ============================= HELPERS =============================

def _norm_feats(names) -> pd.Index:
    """
    Normalizer used ONLY to construct matching keys.
    Panel names remain untouched; data columns are normalized and then mapped BACK
    to the exact panel names via a lookup.
    """
    s = pd.Index(map(str, names))
    s = (s.str.strip()
           .str.lower()
           .str.replace(r"[ _/]+", "-", regex=True)
           .str.replace(r"-+", "-", regex=True)
           .str.strip("-"))
    return s

def attach_celltype(df: pd.DataFrame, ad: "AnnData", field: str) -> pd.DataFrame:
    """
    Return a copy of `df` with a categorical 'Celltype' column taken from `ad.obs[field]`.

    Labels are converted to strings, stripped, and every whitespace run is
    replaced by '_'. They are aligned to `df` by index; rows of `df` whose
    index is absent from `ad.obs` get a missing label and a warning is printed.

    Raises:
        KeyError: if `field` is not a column of `ad.obs`.
    """
    if field not in ad.obs:
        raise KeyError(f"'{field}' not found in AnnData.obs")

    labels = ad.obs[field].astype("string")
    labels = labels.str.strip()
    labels = labels.str.replace(r"\s+", "_", regex=True)

    out = df.copy()
    out["Celltype"] = pd.Categorical(labels.reindex(out.index))

    n_unlabelled = int(out["Celltype"].isna().sum())
    if n_unlabelled:
        print(f"[WARN] {n_unlabelled} rows got NaN Celltype after reindex; check barcode alignment.")
    return out

def _check_finite(df: pd.DataFrame, tag: str):
    arr = df.to_numpy()
    if not np.isfinite(arr).all():
        bad = np.where(~np.isfinite(arr))
        raise ValueError(f"Non-finite values found in {tag} features at positions {bad}")

def _unwrap_estimator(m):
    return getattr(m, "estimator", None) or getattr(m, "base_estimator", None) or m

def _assert_feature_counts(cell_name: str, models_dict: dict, expected: int):
    pairs = [
        ("NB",  models_dict.get("NB")),
        ("XGB", models_dict.get("XGB")),
        ("KNN", models_dict.get("KNN")),
        ("MLP", models_dict.get("MLP")),
        ("Stacker", models_dict.get("Stacker")),
    ]
    for name, est in pairs:
        if est is None:
            continue
        base = _unwrap_estimator(est)
        nfi = getattr(base, "n_features_in_", None)
        if nfi is not None and nfi != expected:
            raise RuntimeError(f"{cell_name}:{name} saw {nfi} features; expected {expected}")

# ============================= LABEL ATTACH =============================

# obs column holding the consensus labels for the configured annotation depth.
consensus_field = f"Consensus_annotation_{name_target_class.lower()}_final"

# Attach 'Celltype' to each split from its matching AnnData (Train, Test, Cal in that order).
Hao_data_Train, Hao_data_Test, Hao_data_Cal = [
    attach_celltype(split_df, split_adata, consensus_field)
    for split_df, split_adata in (
        (Hao_data_Train, Hao_dataset_Train),
        (Hao_data_Test,  Hao_dataset_Test),
        (Hao_data_Cal,   Hao_dataset_Cal),
    )
]

# ============================= PANEL & DATA COLUMN ALIGNMENT =============================

# The panel's own spelling is authoritative: names are stringified, never rewritten.
panel = pd.Index([str(feature) for feature in TotalSeqD_Heme_Oncology_CAT399906])

# Lookup table: normalized matching key -> verbatim panel name.
panel_keys    = _norm_feats(panel)
norm_to_panel = {key: exact_name for key, exact_name in zip(panel_keys, panel)}

# Two panel names sharing one key would make the lookup ambiguous, so fail early.
if len(panel) != len(norm_to_panel):
    raise ValueError("Panel contains names that collide after normalization. Consider adjusting _norm_feats rules.")

def _rename_data_to_panel(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of `df` whose feature columns use the exact panel spelling.

    Feature column names are normalized with `_norm_feats` and looked up in
    `norm_to_panel`; a hit renames the column to the panel name, a miss leaves
    the column name unchanged. 'cell_barcode' and 'Celltype' are never treated
    as features. If several data columns resolve to the same panel name, the
    first one in column order is kept and the others are dropped (a warning is
    printed). A one-line match summary is printed as well.
    """
    renamed = df.copy()
    reserved = [c for c in ["cell_barcode", "Celltype"] if c in renamed.columns]
    feat = pd.Index([c for c in renamed.columns if c not in reserved])

    # old data name -> exact panel name, only for names whose key is in the panel
    candidates = {}
    for old, key in zip(feat, _norm_feats(feat)):
        target = norm_to_panel.get(key)
        if target is not None:
            candidates[old] = target

    # The first claimant of a panel name wins; later claimants are scheduled for removal.
    safe_map, drops, claimed = {}, [], set()
    for old, target in candidates.items():
        if target in claimed:
            drops.append(old)
            continue
        claimed.add(target)
        safe_map[old] = target

    if drops:
        print(f"[WARN] Dropping {len(drops)} duplicated-mapped columns (showing up to 5): {drops[:5]}")
        renamed = renamed.drop(columns=drops, errors="ignore")
    renamed = renamed.rename(columns=safe_map)

    print(f"[map] matched {len(safe_map)}/{len(feat)} data columns to panel")
    return renamed

# Rename the feature columns of every split (Train, Test, Cal in that order);
# the panel itself is left untouched.
Hao_data_Train, Hao_data_Test, Hao_data_Cal = [
    _rename_data_to_panel(split_df)
    for split_df in (Hao_data_Train, Hao_data_Test, Hao_data_Cal)
]

# Restrict each split to panel features, keeping the panel's own ordering
def _panel_intersection(df: pd.DataFrame) -> pd.DataFrame:
    """
    Keep only the feature columns of `df` that are panel names, in `panel` order.

    'cell_barcode' and 'Celltype', when present, are appended after the features.

    Raises:
        ValueError: if no feature column of `df` is a panel name.
    """
    reserved = [name for name in ("cell_barcode", "Celltype") if name in df.columns]
    candidates = pd.Index([name for name in df.columns if name not in reserved])
    shared = panel.intersection(candidates, sort=False)
    if len(shared) == 0:
        raise ValueError("Panel/Data intersection is empty after renaming. Check mapping rules.")
    return df.reindex(columns=[*shared, *reserved])

# Apply the panel restriction to Train, Test and Cal (in that order).
Hao_data_Train, Hao_data_Test, Hao_data_Cal = [
    _panel_intersection(split_df)
    for split_df in (Hao_data_Train, Hao_data_Test, Hao_data_Cal)
]

# ============================= FEATURES & LABELS =============================

# Calibration labels, kept as a one-column frame for the label masking/mapping further down.
Hao_data_Cal_lbl = Hao_data_Cal[["Celltype"]].copy()

# Non-feature columns that are actually present in each split.
drop_cols_train = [col for col in ("cell_barcode", "Celltype") if col in Hao_data_Train.columns]
drop_cols_test  = [col for col in ("cell_barcode", "Celltype") if col in Hao_data_Test.columns]
drop_cols_cal   = [col for col in ("cell_barcode", "Celltype") if col in Hao_data_Cal.columns]

# Feature-only frames for the three splits.
Hao_data_Train_Sub = Hao_data_Train.drop(columns=drop_cols_train, errors="ignore")
Hao_data_Test_Sub  = Hao_data_Test.drop(columns=drop_cols_test,  errors="ignore")
Hao_data_Cal_Sub   = Hao_data_Cal.drop(columns=drop_cols_cal,    errors="ignore")

# SAFETY: all three splits must expose the same features in the same order.
cols_train = list(Hao_data_Train_Sub.columns)
if not (list(Hao_data_Test_Sub.columns) == cols_train == list(Hao_data_Cal_Sub.columns)):
    raise ValueError("Train/Cal/Test feature columns differ after panel intersection!")

# SAFETY: no NaN/inf may reach the scalers and learners.
_check_finite(Hao_data_Train_Sub, "TRAIN")
_check_finite(Hao_data_Test_Sub,  "TEST")
_check_finite(Hao_data_Cal_Sub,   "CAL")

print(f"\n[features] Using {len(cols_train)} panel-intersected features (exact panel names):")
print(cols_train)

# ===== Classes left out of the multiclass set and of the per-class loop =====
EXCLUDE_CLASSES = {"Macrophage", "ILC", "Stroma"}

# Sorted unique non-null TRAIN labels, then the subset that is actually modelled.
all_classes = sorted(Hao_data_Train["Celltype"].dropna().unique())
class_names = [c for c in all_classes if c not in EXCLUDE_CLASSES]
if len(class_names) == 0:
    raise ValueError("After exclusions, class_names is empty.")
print(f"[classes] Included ({len(class_names)}): {class_names}")
# Excluded classes that are really present in TRAIN (reported only when there are any).
missing = [c for c in all_classes if c in EXCLUDE_CLASSES]
if missing:
    print(f"[classes] Excluded: {missing}")

# Column order of P_cal / P_te follows class_names.
K            = len(class_names)
class_to_idx = dict(zip(class_names, range(K)))

# --- Multiclass labels (MASKED to included classes) ---
# CAL: rows whose label is an included class, mapped to integer ids.
mask_cal_mc = Hao_data_Cal_lbl["Celltype"].isin(class_names)
s_cal = Hao_data_Cal_lbl.loc[mask_cal_mc, "Celltype"].map(class_to_idx)
if s_cal.isna().any():
    missing = Hao_data_Cal_lbl.loc[mask_cal_mc & s_cal.isna(), "Celltype"].unique()
    raise ValueError(f"Unknown labels in masked CAL split: {missing}")
y_cal_multiclass = s_cal.to_numpy(dtype=np.int64)

# TEST: same masking and mapping.
mask_test_mc = Hao_data_Test["Celltype"].isin(class_names)
s_te = Hao_data_Test.loc[mask_test_mc, "Celltype"].map(class_to_idx)
if s_te.isna().any():
    missing = Hao_data_Test.loc[mask_test_mc & s_te.isna(), "Celltype"].unique()
    raise ValueError(f"Unknown labels in masked TEST split: {missing}")
y_test_multiclass = s_te.to_numpy(dtype=np.int64)

# Full (unmasked) CAL / TEST feature frames, shared by every binary head.
X_cal_all_df = Hao_data_Cal_Sub.copy()
X_te_all_df  = Hao_data_Test_Sub.copy()

# One-vs-rest probability matrices: one row per CAL/TEST cell, one column per included class.
P_cal = np.zeros((len(X_cal_all_df), K), dtype=float)
P_te  = np.zeros((len(X_te_all_df),  K), dtype=float)

test_index = Hao_data_Test_Sub.index

# ============================= TRAIN PER-CLASS OVR =============================

# One binary (one-vs-rest) head per included class: slice TRAIN/TEST by barcode
# lists, scale, train base learners + stacker, calibrate on CAL, save, and record
# the head's probabilities in column k of P_cal / P_te.
for celltype in class_names:
    k = class_to_idx[celltype]
    name = str(celltype).replace(" ", "_")
    print(f"\nProcessing {name} (class {k+1}/{K})...")

    # ---- TRAIN slice via barcode lists
    # The CSV must provide "Positive" and "Negative" barcode columns; NaN entries are dropped.
    train_barcodes_df = pd.read_csv(
        f"{train_barcodes_path}/Hao/Consensus_annotation_simplified_final/Barcodes_training_class_{name}.csv",
        index_col=0
    )
    train_positive_barcodes = train_barcodes_df["Positive"].dropna().values
    train_negative_barcodes = train_barcodes_df["Negative"].dropna().values
    all_train_barcodes = np.concatenate([train_positive_barcodes, train_negative_barcodes])

    train_mask = Hao_data_Train_Sub.index.isin(all_train_barcodes)
    X_tr_df = Hao_data_Train_Sub.loc[train_mask]
    found_train_barcodes = X_tr_df.index.values
    y_tr = np.isin(found_train_barcodes, train_positive_barcodes).astype(int)

    # ---- Skip guards
    # NOTE: a skipped class keeps the preallocated all-zero column k in P_cal / P_te.
    if X_tr_df.empty or np.unique(y_tr).size < 2:
        print(f"[SKIP] {name}: empty or single-class train slice (pos={y_tr.sum()}, neg={(len(y_tr)-y_tr.sum())}).")
        continue

    # ---- TEST slice via barcode lists
    test_barcodes_df = pd.read_csv(
        f"{test_barcodes_path}/Hao/Consensus_annotation_simplified_final/Barcodes_testing_class_{name}.csv",
        index_col=0
    )
    test_positive_barcodes = test_barcodes_df["Positive"].dropna().values
    test_negative_barcodes = test_barcodes_df["Negative"].dropna().values
    all_test_barcodes = np.concatenate([test_positive_barcodes, test_negative_barcodes])

    test_mask = Hao_data_Test_Sub.index.isin(all_test_barcodes)
    X_te_df = Hao_data_Test_Sub.loc[test_mask]
    found_test_barcodes = X_te_df.index.values
    y_te = np.isin(found_test_barcodes, test_positive_barcodes).astype(int)

    # ---- Full-test & cal for this binary head
    # These frames are only read below (scaling builds new frames), so the shared
    # frames are referenced directly instead of being deep-copied for every class.
    X_te_all_local = X_te_all_df
    y_te_all = (Hao_data_Test["Celltype"].values == celltype).astype(int)
    X_cal_df = X_cal_all_df
    y_cal_bin = (Hao_data_Cal_lbl["Celltype"].values == celltype).astype(int)

    # ---- Info
    print(f"Training - Found {X_tr_df.shape[0]} / {len(all_train_barcodes)} barcodes")
    print(f"Training - Pos: {len(train_positive_barcodes)}, Neg: {len(train_negative_barcodes)}")
    print(f"Training labels: {y_tr.sum()} pos, {len(y_tr)-y_tr.sum()} neg")
    print(f"Testing  - Found {X_te_df.shape[0]} / {len(all_test_barcodes)} barcodes")
    print(f"Testing  - Pos: {len(test_positive_barcodes)}, Neg: {len(test_negative_barcodes)}")
    print(f"Testing  - labels: {y_te.sum()} pos, {len(y_te)-y_te.sum()} neg")
    print(f"Calibrating - Found {X_cal_df.shape[0]} rows | Pos: {y_cal_bin.sum()}, Neg: {len(y_cal_bin)-y_cal_bin.sum()}")
    print(f"All test data: {X_te_all_local.shape[0]} rows, positives for {celltype}: {y_te_all.sum()}")

    # ---- Scaling (fit on per-head TRAIN slice; transform others)
    scaler = StandardScaler(with_mean=True, with_std=True).fit(X_tr_df.values)

    def _sc(df):
        """Scale `df` with this head's scaler, keeping its index and the shared feature names."""
        return pd.DataFrame(
            scaler.transform(df.values),
            index=df.index,
            columns=cols_train,
        )

    X_tr_sc_df     = _sc(X_tr_df)
    # scikit-learn's input validation rejects 0-row arrays, so an empty per-class
    # test slice must not be sent through the scaler; binary metrics are skipped
    # for it further down instead of aborting the whole training run here.
    has_test_slice = not X_te_df.empty
    X_te_sc_df     = _sc(X_te_df) if has_test_slice else X_te_df.copy()
    X_te_all_sc_df = _sc(X_te_all_local)
    X_cal_sc_df    = _sc(X_cal_df)

    print(f"[scale] {name}: train mean ~ {X_tr_sc_df.values.mean():.3f}, std ~ {X_tr_sc_df.values.std():.3f}")

    # ---- Base learners
    NB_model  = MLTraining.train_NB (X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    XGB_model = MLTraining.train_XGB(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    KNN_model = MLTraining.train_KNN(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    MLP_model = MLTraining.train_MLP(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)

    # ---- Stacker (raw)
    stacker_raw = StackingClassifier(
        estimators=[("NB", NB_model), ("XGB", XGB_model), ("KNN", KNN_model), ("MLP", MLP_model)],
        final_estimator=LogisticRegression(max_iter=2000, class_weight="balanced", random_state=42),
        stack_method="predict_proba",
        cv=kf,
        n_jobs=-1,
    ).fit(X_tr_sc_df, y_tr)

    # ---- Feature count asserts (debug safety)
    expected_feats = len(cols_train)
    _assert_feature_counts(name, {
        "NB": NB_model, "XGB": XGB_model, "KNN": KNN_model, "MLP": MLP_model, "Stacker": stacker_raw
    }, expected_feats)

    # ---- Binary calibration (Platt, guarded)
    pos_cal    = int(y_cal_bin.sum())
    n_cal_bin  = int(len(y_cal_bin))
    has_both   = (0 < pos_cal < n_cal_bin)   # sigmoid calibration needs both classes in CAL
    print(f"[CAL] {name}: cal positives={pos_cal}/{n_cal_bin}")

    if has_both:
        try:
            calibrator = CalibratedClassifierCV(estimator=stacker_raw, method="sigmoid", cv="prefit")
        except TypeError:  # older sklearn
            calibrator = CalibratedClassifierCV(base_estimator=stacker_raw, method="sigmoid", cv="prefit")
        stacker = calibrator.fit(X_cal_sc_df, y_cal_bin)
    else:
        print(f"[WARN] Skipping calibration for {name}: single-class cal set.")
        stacker = stacker_raw

    # ---- Calibration plot on all-test (optional)
    # y_proba_cal stays None if prediction fails inside the best-effort plot block;
    # when it succeeds it is reused for P_te so the full test set is scored only once.
    y_proba_cal = None
    try:
        y_proba_uncal = stacker_raw.predict_proba(X_te_all_sc_df)[:, 1]
        # Without calibration `stacker` is `stacker_raw`, so the raw scores are reused.
        y_proba_cal   = stacker.predict_proba(X_te_all_sc_df)[:, 1] if has_both else y_proba_uncal
        if has_both:
            _ = MLTraining.plot_calibration_curve(
                y_te_all, [y_proba_uncal, y_proba_cal],
                clf_names=["Uncalibrated", "Calibrated"],
                n_bins=15, strategy="quantile",
                title=f"Calibration – {name_target_class}:{name}"
            )
        else:
            _ = MLTraining.plot_calibration_curve(
                y_te_all, [y_proba_uncal],
                clf_names=["Uncalibrated"],
                n_bins=15, strategy="quantile",
                title=f"Calibration (uncal only) – {name_target_class}:{name}"
            )
        plt.tight_layout()
        plt.show()
    except Exception as e:
        print(f"[WARN] Skipped calibration plot for {name}: {e}")

    # ---- Save per-class bundle (model + scaler + columns)
    save_subdir = models_dir / f"{name_target_class}_{name}"
    save_subdir.mkdir(parents=True, exist_ok=True)

    MLTraining.save_models({"Stacked": stacker}, out_dir=save_subdir, tag=f"{name_target_class}_{name}")
    joblib.dump(cols_train, save_subdir / "feature_names.joblib")

    bundle = {
        "atlas": "Hao",
        "depth": name_target_class,
        "label": celltype,
        "model": stacker,          # CalibratedClassifierCV(StackingClassifier) or StackingClassifier
        "columns": cols_train,     # exact panel names, panel order
        "scaler": scaler,          # per-head scaler
        "panel_name": "TotalSeqD_Heme_Oncology_CAT399906",
    }
    bundle_path = save_subdir / f"{name_target_class}_{name}_bundle.joblib"
    joblib.dump(bundle, bundle_path)
    print(f"[SAVE] Wrote bundle with columns+scaler to {bundle_path}")

    # ---- Binary metrics on the class-specific test slice (best-effort)
    if has_test_slice:
        try:
            m = MLTraining.evaluate_classifier(stacker, X_te_sc_df, y_te, plot_cm=False)
            m.update(celltype=celltype)
            metrics_log.append(m)
            print(f"\n{celltype}\n", m.get("report", ""))
        except Exception as e:
            print(f"[WARN] Binary metrics for {name} skipped: {e}")
    else:
        print(f"[WARN] Binary metrics for {name} skipped: no test barcodes found in TEST split.")

    # ---- Store OvR probs for multiclass calibration (columns order = class_names)
    P_cal[:, k] = stacker.predict_proba(X_cal_sc_df)[:, 1]
    if y_proba_cal is None:
        # The plot block failed before scoring; score the full test set here.
        y_proba_cal = stacker.predict_proba(X_te_all_sc_df)[:, 1]
    P_te[:,  k] = y_proba_cal

# ============================= MULTICLASS CALIBRATION =============================

print("\nFitting multiclass TemperatureScaling on CAL split (excluded classes masked out)...")

# Guards: every entry must be a probability in [0, 1]. The check is phrased
# positively so NaN entries (which fail every comparison) are rejected as well.
if not ((P_cal >= 0) & (P_cal <= 1)).all():
    raise ValueError("P_cal must be probabilities in [0,1].")
if not ((P_te >= 0) & (P_te <= 1)).all():
    raise ValueError("P_te must be probabilities in [0,1].")

ts_cal = TemperatureScaling()
# Fit only on CAL rows whose true label is one of the included classes
ts_cal.fit(P_cal[mask_cal_mc.values, :], y_cal_multiclass)
P_te_mc = ts_cal.transform(P_te)

# Ensure calibrated probs shape (n_test, K)
P_te_mc = np.asarray(P_te_mc)
ts_returned_shape = P_te_mc.shape   # kept so the fallback warning reports what was actually returned
if P_te_mc.ndim == 1:
    P_te_mc = P_te_mc.reshape(-1, 1)
if P_te_mc.shape[1] == 1 and K == 2:
    # Single positive-class column in the binary case: rebuild the two-column form.
    P_te_mc = np.hstack([1.0 - P_te_mc, P_te_mc])
elif P_te_mc.shape[1] != K:
    # Fallback: normalize OvR sums (rows summing to 0 are left as all zeros)
    row_sums = P_te.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0.0] = 1.0
    P_te_mc = P_te / row_sums
    print(f"[WARN] TemperatureScaling returned shape {ts_returned_shape}; fell back to sum-normalized OvR probs.")

# Persist multiclass temp scaler + INCLUDED class names
joblib.dump(ts_cal, models_dir / f"{name_target_class}_multiclass_temp_scaler.joblib")
(pd.Series(class_names, name="class_name")
   .to_csv(models_dir / f"{name_target_class}_class_names.csv", index=False))

# ============================= PROBS COMPARISON & METRICS =============================

# Evaluate & save on TEST rows whose true label is an INCLUDED class.
# The masked slices and the argmax predictions are computed once and reused below.
test_rows_included = mask_test_mc.values
test_index_masked  = Hao_data_Test_Sub.index[test_rows_included]
P_te_masked        = P_te[test_rows_included, :]
P_te_mc_masked     = P_te_mc[test_rows_included, :]
y_pred_raw         = P_te_masked.argmax(axis=1)
y_pred_mc          = P_te_mc_masked.argmax(axis=1)

probs_raw_df = pd.DataFrame(P_te_masked,    index=test_index_masked,
                            columns=[f"raw_{c}" for c in class_names])
probs_mc_df  = pd.DataFrame(P_te_mc_masked, index=test_index_masked,
                            columns=[f"mc_{c}"  for c in class_names])

probs_compare = pd.concat([probs_raw_df, probs_mc_df], axis=1)
probs_compare["true_label"]    = Hao_data_Test.loc[mask_test_mc, "Celltype"].values
probs_compare["pred_raw"]      = y_pred_raw
probs_compare["pred_mc"]       = y_pred_mc
probs_compare["pred_raw_name"] = [class_names[i] for i in y_pred_raw]
probs_compare["pred_mc_name"]  = [class_names[i] for i in y_pred_mc]

print("\nPreview of probabilities BEFORE (raw OvR) vs AFTER (multiclass TS) [included classes only]:")
print(probs_compare.head(10).to_string())

probs_compare_path = models_dir / f"{name_target_class}_probabilities_before_after_TEST_included.csv"
probs_compare.to_csv(probs_compare_path, index=True)
print(f"\nSaved probabilities comparison to: {probs_compare_path}")

# Multiclass evaluation on the masked subset
# `labels` is passed explicitly (as for the confusion matrix) so the report keeps
# one row per included class and does not raise when a class is absent from both
# the true and the predicted labels of the masked TEST rows.
all_label_ids = list(range(K))
print("\nMulticlass classification report (TEST, excluded classes removed):")
print(classification_report(y_test_multiclass, y_pred_mc, labels=all_label_ids,
                            target_names=class_names, digits=3))

cm = confusion_matrix(y_test_multiclass, y_pred_mc, labels=all_label_ids)
print("Confusion matrix (rows=true, cols=pred):\n", cm)

# Per-class binary head metrics CSV
metrics_df = pd.DataFrame.from_records(metrics_log)
MLTraining.append_metrics_csv(metrics_df, csv_path=Path(models_output) / "stacker_metrics.csv")

print("\nDone.")


#### Detailed annotation

In [None]:
# -*- coding: utf-8 -*-
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.ensemble import StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.calibration import CalibratedClassifierCV
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix

from netcal.scaling import TemperatureScaling
import joblib

# ------------------------------------------------------------------- CONFIG (expects these to already exist)
#   models_output, train_barcodes_path, test_barcodes_path
#   Hao_data_Train, Hao_data_Test, Hao_data_Cal          (DataFrames indexed by barcode)
#   Hao_dataset_Train, Hao_dataset_Test, Hao_dataset_Cal (AnnData with obs labels)
#   TotalSeqD_Heme_Oncology_CAT399906                    (iterable of feature names)
#   MLTraining module with: CV, train_NB, train_XGB, train_KNN, train_MLP,
#                           plot_calibration_curve, save_models, evaluate_classifier, append_metrics_csv

# Annotation depth handled by this cell. It drives the AnnData label column
# (f"Consensus_annotation_{name_target_class.lower()}_final") and the output file names.
name_target_class = "Detailed"   # one of: "Broad" | "Simplified" | "Detailed"
# Output locations, both created up-front under `models_output`.
fig_root   = Path(models_output) / "Figures"
models_dir = Path(models_output) / "Models"
fig_root.mkdir(parents=True, exist_ok=True)
models_dir.mkdir(parents=True, exist_ok=True)

# Cross-validation splitter shared by the base learners and the stacker
# (rebinds the notebook-level `kf` to the object exposed by MLTraining).
kf         = MLTraining.CV
# Forwarded as `num_cores=` to the MLTraining.train_* helpers.
# NOTE(review): -1 presumably means "use all cores" (joblib convention) — confirm in MLTraining.
num_cores  = -1
# One metrics dict per successfully evaluated binary head; written to CSV at the end of the cell.
metrics_log = []

# ============================= HELPERS =============================

def _norm_feats(names) -> pd.Index:
    """
    Normalizer used ONLY to construct matching keys.
    Panel names remain untouched; data columns are normalized and then mapped BACK
    to the exact panel names via a lookup.
    """
    s = pd.Index(map(str, names))
    s = (s.str.strip()
           .str.lower()
           .str.replace(r"[ _/]+", "-", regex=True)
           .str.replace(r"-+", "-", regex=True)
           .str.strip("-"))
    return s

def attach_celltype(df: pd.DataFrame, ad: "AnnData", field: str) -> pd.DataFrame:
    """
    Return a copy of `df` with a categorical 'Celltype' column taken from `ad.obs[field]`.

    Labels are cast to string, trimmed, and every whitespace run is replaced by "_".
    They are aligned to `df` by index; rows of `df` whose index is absent from
    `ad.obs` get a missing Celltype and a warning with their count is printed.

    Raises:
        KeyError: if `field` is not a column of `ad.obs`.
    """
    obs = ad.obs
    if field not in obs:
        raise KeyError(f"'{field}' not found in AnnData.obs")

    labels = obs[field].astype("string")
    labels = labels.str.strip()
    labels = labels.str.replace(r"\s+", "_", regex=True)

    annotated = df.copy()
    annotated["Celltype"] = pd.Categorical(labels.reindex(annotated.index))

    n_unlabelled = int(annotated["Celltype"].isna().sum())
    if n_unlabelled:
        print(f"[WARN] {n_unlabelled} rows got NaN Celltype after reindex; check barcode alignment.")
    return annotated

def _check_finite(df: pd.DataFrame, tag: str):
    arr = df.to_numpy()
    if not np.isfinite(arr).all():
        bad = np.where(~np.isfinite(arr))
        raise ValueError(f"Non-finite values found in {tag} features at positions {bad}")

def _unwrap_estimator(m):
    return getattr(m, "estimator", None) or getattr(m, "base_estimator", None) or m

def _assert_feature_counts(cell_name: str, models_dict: dict, expected: int):
    """
    Debug guard: verify every fitted model saw exactly `expected` input features.

    Args:
        cell_name: class/head name, used only in the error message.
        models_dict: may contain the keys "NB", "XGB", "KNN", "MLP", "Stacker";
            missing or None entries are ignored.
        expected: number of feature columns the models should have been fit on.

    Raises:
        RuntimeError: if any inspected object reports a different `n_features_in_`.

    Both the object itself and the estimator it wraps are inspected. Looking only at
    the wrapped estimator silently skips the check for wrappers that keep an unfitted
    template in `.estimator` (e.g. hyper-parameter search objects), because that
    template has no `n_features_in_`.
    """
    for label in ("NB", "XGB", "KNN", "MLP", "Stacker"):
        est = models_dict.get(label)
        if est is None:
            continue
        for candidate in (est, _unwrap_estimator(est)):
            # getattr default also absorbs the AttributeError raised by unfitted objects.
            nfi = getattr(candidate, "n_features_in_", None)
            if nfi is not None and nfi != expected:
                raise RuntimeError(f"{cell_name}:{label} saw {nfi} features; expected {expected}")

# ============================= LABEL ATTACH =============================

# AnnData.obs column holding the consensus labels for the selected annotation depth.
consensus_field = f"Consensus_annotation_{name_target_class.lower()}_final"

# Each split is REBOUND to a copy that carries a categorical 'Celltype' column
# (attach_celltype aligns the labels to the frame's index).
Hao_data_Train = attach_celltype(Hao_data_Train, Hao_dataset_Train, consensus_field)
Hao_data_Test  = attach_celltype(Hao_data_Test,  Hao_dataset_Test,  consensus_field)
Hao_data_Cal   = attach_celltype(Hao_data_Cal,   Hao_dataset_Cal,   consensus_field)

# ============================= PANEL & DATA COLUMN ALIGNMENT =============================

# Keep the panel EXACTLY as provided (only cast to str); its spelling is the target
# spelling for the data columns.
panel = pd.Index(map(str, TotalSeqD_Heme_Oncology_CAT399906))

# Build a mapping: normalized_key -> exact panel name
panel_keys    = _norm_feats(panel)
norm_to_panel = dict(zip(panel_keys, panel))
# A dict keeps one entry per key, so a shorter dict means two panel names share a
# normalized key (or the panel contains a repeated name) and the lookup would be ambiguous.
if len(norm_to_panel) != len(panel):
    raise ValueError("Panel contains names that collide after normalization. Consider adjusting _norm_feats rules.")

def _rename_data_to_panel(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a new frame whose feature columns carry the exact panel spelling.

    Every column except 'cell_barcode' and 'Celltype' is normalized with
    `_norm_feats` and looked up in `norm_to_panel`. Columns without a panel match
    keep their name. When several data columns resolve to the same panel name, the
    first one (in column order) is renamed and the others are dropped with a warning.
    """
    reserved = {"cell_barcode", "Celltype"}
    feature_cols = pd.Index([col for col in df.columns if col not in reserved])

    # data column -> panel name, only for columns that have a match in the panel
    targets = {}
    for col, key in zip(feature_cols, _norm_feats(feature_cols)):
        panel_name = norm_to_panel.get(key)
        if panel_name is not None:
            targets[col] = panel_name

    # First column to claim a panel name wins; later claimants are dropped.
    claimed, keep_map, duplicate_cols = set(), {}, []
    for col, panel_name in targets.items():
        if panel_name in claimed:
            duplicate_cols.append(col)
            continue
        claimed.add(panel_name)
        keep_map[col] = panel_name

    renamed = df
    if duplicate_cols:
        print(f"[WARN] Dropping {len(duplicate_cols)} duplicated-mapped columns (showing up to 5): {duplicate_cols[:5]}")
        renamed = renamed.drop(columns=duplicate_cols, errors="ignore")
    # rename() returns a new frame, so the caller's `df` is never modified.
    renamed = renamed.rename(columns=keep_map)

    print(f"[map] matched {len(keep_map)}/{len(feature_cols)} data columns to panel")
    return renamed

# Apply: normalize/rename ONLY data splits (panel remains untouched)
Hao_data_Train = _rename_data_to_panel(Hao_data_Train)
Hao_data_Test  = _rename_data_to_panel(Hao_data_Test)
Hao_data_Cal   = _rename_data_to_panel(Hao_data_Cal)

# Intersect each split with the panel IN PANEL ORDER
def _panel_intersection(df: pd.DataFrame) -> pd.DataFrame:
    """
    Keep only the feature columns that are part of `panel`, ordered as in `panel`,
    followed by the metadata columns ('cell_barcode', 'Celltype') that are present.

    Raises:
        ValueError: if no feature column of `df` belongs to the panel.
    """
    meta_cols = [name for name in ("cell_barcode", "Celltype") if name in df.columns]
    candidates = pd.Index([col for col in df.columns if col not in meta_cols])
    shared = panel.intersection(candidates, sort=False)
    if len(shared) == 0:
        raise ValueError("Panel/Data intersection is empty after renaming. Check mapping rules.")
    ordered = [*shared, *meta_cols]
    return df.reindex(columns=ordered)

Hao_data_Train = _panel_intersection(Hao_data_Train)
Hao_data_Test  = _panel_intersection(Hao_data_Test)
Hao_data_Cal   = _panel_intersection(Hao_data_Cal)

# ============================= FEATURES & LABELS =============================

# CAL labels are kept in their own one-column frame; the binary and multiclass
# calibration targets below are derived from it.
Hao_data_Cal_lbl = Hao_data_Cal[["Celltype"]].copy()

# Metadata columns actually present in each split (the feature matrices must not contain them).
drop_cols_train = [c for c in ["cell_barcode", "Celltype"] if c in Hao_data_Train.columns]
drop_cols_test  = [c for c in ["cell_barcode", "Celltype"] if c in Hao_data_Test.columns]
drop_cols_cal   = [c for c in ["cell_barcode", "Celltype"] if c in Hao_data_Cal.columns]

# *_Sub frames are the pure feature matrices, still indexed like their parent split.
Hao_data_Train_Sub = Hao_data_Train.drop(columns=drop_cols_train, errors="ignore")
Hao_data_Test_Sub  = Hao_data_Test.drop(columns=drop_cols_test,  errors="ignore")
Hao_data_Cal_Sub   = Hao_data_Cal.drop(columns=drop_cols_cal,    errors="ignore")

# SAFETY: shared columns & finiteness checks
# Each split was intersected with the panel independently, so the three column lists
# can differ; they must match in both content AND order before any model is fit.
cols_train = list(Hao_data_Train_Sub.columns)
if list(Hao_data_Test_Sub.columns) != cols_train or list(Hao_data_Cal_Sub.columns) != cols_train:
    raise ValueError("Train/Cal/Test feature columns differ after panel intersection!")

# Raises ValueError on the first split that contains NaN/inf.
_check_finite(Hao_data_Train_Sub, "TRAIN")
_check_finite(Hao_data_Test_Sub,  "TEST")
_check_finite(Hao_data_Cal_Sub,   "CAL")

print(f"\n[features] Using {len(cols_train)} panel-intersected features (exact panel names):")
print(cols_train)

# ===== Exclude specific classes from the multiclass set and per-class loop =====
EXCLUDE_CLASSES = {"Macrophage", "ILC", "Stroma", "dnT"}

# Class list comes from the TRAIN labels only; sorting fixes the column order of the
# probability matrices (column i <-> class_names[i]).
all_classes = sorted(pd.Series(Hao_data_Train["Celltype"]).dropna().unique())
class_names = [c for c in all_classes if c not in EXCLUDE_CLASSES]
if not class_names:
    raise ValueError("After exclusions, class_names is empty.")
print(f"[classes] Included ({len(class_names)}): {class_names}")
# Report only the exclusions that actually occur in TRAIN.
if missing := [c for c in all_classes if c in EXCLUDE_CLASSES]:
    print(f"[classes] Excluded: {missing}")

K            = len(class_names)
class_to_idx = {c: i for i, c in enumerate(class_names)}

# --- Multiclass labels (MASKED to included classes) ---
# Rows whose label is excluded, missing, or not seen in TRAIN are masked out
# (isin() is False for all of them).
# CAL
mask_cal_mc = Hao_data_Cal_lbl["Celltype"].isin(class_names)
s_cal = Hao_data_Cal_lbl.loc[mask_cal_mc, "Celltype"].map(class_to_idx)
# Defensive check: after the isin() mask every remaining label is a key of
# class_to_idx, so this branch is not expected to trigger.
if s_cal.isna().any():
    missing = Hao_data_Cal_lbl.loc[mask_cal_mc & s_cal.isna(), "Celltype"].unique()
    raise ValueError(f"Unknown labels in masked CAL split: {missing}")
y_cal_multiclass = s_cal.to_numpy(dtype=np.int64)

# TEST
mask_test_mc = Hao_data_Test["Celltype"].isin(class_names)
s_te = Hao_data_Test.loc[mask_test_mc, "Celltype"].map(class_to_idx)
# Same defensive check as for CAL.
if s_te.isna().any():
    missing = Hao_data_Test.loc[mask_test_mc & s_te.isna(), "Celltype"].unique()
    raise ValueError(f"Unknown labels in masked TEST split: {missing}")
y_test_multiclass = s_te.to_numpy(dtype=np.int64)

# Reuse across classes
X_cal_all_df = Hao_data_Cal_Sub.copy()
X_te_all_df  = Hao_data_Test_Sub.copy()

# Preallocate OvR prob mats (only for included classes)
# Rows cover the FULL CAL/TEST splits (unmasked); column k is filled by the head of
# class_names[k] in the training loop and stays 0 if that head is skipped.
P_cal = np.zeros((X_cal_all_df.shape[0], K), dtype=float)
P_te  = np.zeros((X_te_all_df.shape[0],  K), dtype=float)

# NOTE(review): not referenced again in the visible part of this cell.
test_index = Hao_data_Test_Sub.index

# ============================= TRAIN PER-CLASS OVR =============================
# One binary (one-vs-rest) stacked head is trained per included class. For every head:
#   1. select the TRAIN/TEST rows listed in that class's barcode CSVs,
#   2. fit a StandardScaler on the head's TRAIN slice and apply it to every other slice,
#   3. train the NB/XGB/KNN/MLP base learners and stack them with a logistic regression,
#   4. sigmoid-calibrate the stacker on the CAL split when CAL contains both classes,
#   5. persist the model bundle and log binary metrics on the head's TEST slice,
#   6. write the head's positive-class probabilities into column k of P_cal / P_te.

# Heads that could not be trained. Their columns in P_cal / P_te stay all-zero, which
# affects everything that consumes those matrices, so they are summarised after the loop.
skipped_heads = []

for celltype in class_names:
    k = class_to_idx[celltype]
    name = str(celltype).replace(" ", "_")
    print(f"\nProcessing {name} (class {k+1}/{K})...")

    # ---- TRAIN slice via barcode lists
    # The sub-directory is taken from `consensus_field` so it always matches the
    # annotation depth the labels were attached with.
    train_barcodes_df = pd.read_csv(
        f"{train_barcodes_path}/Hao/{consensus_field}/Barcodes_training_class_{name}.csv",
        index_col=0
    )
    train_positive_barcodes = train_barcodes_df["Positive"].dropna().values
    train_negative_barcodes = train_barcodes_df["Negative"].dropna().values
    all_train_barcodes = np.concatenate([train_positive_barcodes, train_negative_barcodes])

    train_mask = Hao_data_Train_Sub.index.isin(all_train_barcodes)
    X_tr_df = Hao_data_Train_Sub.loc[train_mask]
    found_train_barcodes = X_tr_df.index.values
    # 1 = barcode listed as positive for this class, 0 = listed as negative
    y_tr = np.isin(found_train_barcodes, train_positive_barcodes).astype(int)

    # ---- Skip guards
    if X_tr_df.empty or np.unique(y_tr).size < 2:
        print(f"[SKIP] {name}: empty or single-class train slice (pos={y_tr.sum()}, neg={(len(y_tr)-y_tr.sum())}).")
        skipped_heads.append(name)
        continue

    # ---- TEST slice via barcode lists
    test_barcodes_df = pd.read_csv(
        f"{test_barcodes_path}/Hao/{consensus_field}/Barcodes_testing_class_{name}.csv",
        index_col=0
    )
    test_positive_barcodes = test_barcodes_df["Positive"].dropna().values
    test_negative_barcodes = test_barcodes_df["Negative"].dropna().values
    all_test_barcodes = np.concatenate([test_positive_barcodes, test_negative_barcodes])

    test_mask = Hao_data_Test_Sub.index.isin(all_test_barcodes)
    X_te_df = Hao_data_Test_Sub.loc[test_mask]
    found_test_barcodes = X_te_df.index.values
    y_te = np.isin(found_test_barcodes, test_positive_barcodes).astype(int)

    # ---- Full-test & cal for this binary head
    # The shared frames are only read (scaling returns new frames), so they are
    # referenced directly instead of being copied once per class.
    X_te_all_local = X_te_all_df
    y_te_all = (Hao_data_Test["Celltype"].values == celltype).astype(int)
    X_cal_df = X_cal_all_df
    y_cal_bin = (Hao_data_Cal_lbl["Celltype"].values == celltype).astype(int)

    # ---- Info
    print(f"Training - Found {X_tr_df.shape[0]} / {len(all_train_barcodes)} barcodes")
    print(f"Training - Pos: {len(train_positive_barcodes)}, Neg: {len(train_negative_barcodes)}")
    print(f"Training labels: {y_tr.sum()} pos, {len(y_tr)-y_tr.sum()} neg")
    print(f"Testing  - Found {X_te_df.shape[0]} / {len(all_test_barcodes)} barcodes")
    print(f"Testing  - Pos: {len(test_positive_barcodes)}, Neg: {len(test_negative_barcodes)}")
    print(f"Testing  - labels: {y_te.sum()} pos, {len(y_te)-y_te.sum()} neg")
    print(f"Calibrating - Found {X_cal_df.shape[0]} rows | Pos: {y_cal_bin.sum()}, Neg: {len(y_cal_bin)-y_cal_bin.sum()}")
    print(f"All test data: {X_te_all_local.shape[0]} rows, positives for {celltype}: {y_te_all.sum()}")

    # ---- Scaling (fit on per-head TRAIN slice; transform others)
    scaler = StandardScaler(with_mean=True, with_std=True).fit(X_tr_df.values)

    def _sc(df):
        """Apply this head's scaler and return a frame with the original index."""
        return pd.DataFrame(
            scaler.transform(df.values),
            index=df.index,
            columns=cols_train,
        )

    X_tr_sc_df     = _sc(X_tr_df)
    # The scaler rejects a 0-row input. An empty TEST slice is passed through unscaled
    # so that only the best-effort binary metrics below are skipped, not the whole head.
    X_te_sc_df     = _sc(X_te_df) if not X_te_df.empty else X_te_df.copy()
    X_te_all_sc_df = _sc(X_te_all_local)
    X_cal_sc_df    = _sc(X_cal_df)

    print(f"[scale] {name}: train mean ~ {X_tr_sc_df.values.mean():.3f}, std ~ {X_tr_sc_df.values.std():.3f}")

    # ---- Base learners
    NB_model  = MLTraining.train_NB (X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    XGB_model = MLTraining.train_XGB(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    KNN_model = MLTraining.train_KNN(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    MLP_model = MLTraining.train_MLP(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)

    # ---- Stacker (raw)
    stacker_raw = StackingClassifier(
        estimators=[("NB", NB_model), ("XGB", XGB_model), ("KNN", KNN_model), ("MLP", MLP_model)],
        final_estimator=LogisticRegression(max_iter=2000, class_weight="balanced", random_state=42),
        stack_method="predict_proba",
        cv=kf,
        n_jobs=-1,
    ).fit(X_tr_sc_df, y_tr)

    # ---- Feature count asserts (debug safety)
    expected_feats = len(cols_train)
    _assert_feature_counts(name, {
        "NB": NB_model, "XGB": XGB_model, "KNN": KNN_model, "MLP": MLP_model, "Stacker": stacker_raw
    }, expected_feats)

    # ---- Binary calibration (Platt, guarded)
    pos_cal    = int(y_cal_bin.sum())
    n_cal_bin  = int(len(y_cal_bin))
    has_both   = (0 < pos_cal < n_cal_bin)
    print(f"[CAL] {name}: cal positives={pos_cal}/{n_cal_bin}")

    if has_both:
        try:
            calibrator = CalibratedClassifierCV(estimator=stacker_raw, method="sigmoid", cv="prefit")
        except TypeError:  # older sklearn
            calibrator = CalibratedClassifierCV(base_estimator=stacker_raw, method="sigmoid", cv="prefit")
        stacker = calibrator.fit(X_cal_sc_df, y_cal_bin)
    else:
        print(f"[WARN] Skipping calibration for {name}: single-class cal set.")
        stacker = stacker_raw

    # ---- Calibration plot on all-test (optional)
    # Best-effort: a plotting failure must not abort training, hence the broad except.
    try:
        y_proba_uncal = stacker_raw.predict_proba(X_te_all_sc_df)[:, 1]
        y_proba_cal   = stacker.predict_proba(X_te_all_sc_df)[:, 1]
        if has_both:
            _ = MLTraining.plot_calibration_curve(
                y_te_all, [y_proba_uncal, y_proba_cal],
                clf_names=["Uncalibrated", "Calibrated"],
                n_bins=15, strategy="quantile",
                title=f"Calibration – {name_target_class}:{name}"
            )
        else:
            _ = MLTraining.plot_calibration_curve(
                y_te_all, [y_proba_uncal],
                clf_names=["Uncalibrated"],
                n_bins=15, strategy="quantile",
                title=f"Calibration (uncal only) – {name_target_class}:{name}"
            )
        plt.tight_layout()
        plt.show()
    except Exception as e:
        print(f"[WARN] Skipped calibration plot for {name}: {e}")

    # ---- Save per-class bundle (model + scaler + columns)
    save_subdir = models_dir / f"{name_target_class}_{name}"
    save_subdir.mkdir(parents=True, exist_ok=True)

    MLTraining.save_models({"Stacked": stacker}, out_dir=save_subdir, tag=f"{name_target_class}_{name}")
    joblib.dump(cols_train, save_subdir / "feature_names.joblib")

    bundle = {
        "atlas": "Hao",
        "depth": name_target_class,
        "label": celltype,
        "model": stacker,          # CalibratedClassifierCV(StackingClassifier) or StackingClassifier
        "columns": cols_train,     # exact panel names, panel order
        "scaler": scaler,          # per-head scaler
        "panel_name": "TotalSeqD_Heme_Oncology_CAT399906",
    }
    bundle_path = save_subdir / f"{name_target_class}_{name}_bundle.joblib"
    joblib.dump(bundle, bundle_path)
    print(f"[SAVE] Wrote bundle with columns+scaler to {bundle_path}")

    # ---- Binary metrics on the class-specific test slice
    try:
        m = MLTraining.evaluate_classifier(stacker, X_te_sc_df, y_te, plot_cm=False)
        m.update(celltype=celltype)
        metrics_log.append(m)
        print(f"\n{celltype}\n", m.get("report", ""))
    except Exception as e:
        print(f"[WARN] Binary metrics for {name} skipped: {e}")

    # ---- Store OvR probs for multiclass calibration (columns order = class_names)
    P_cal[:, k] = stacker.predict_proba(X_cal_sc_df)[:, 1]
    P_te[:,  k] = stacker.predict_proba(X_te_all_sc_df)[:, 1]

if skipped_heads:
    print(f"\n[WARN] {len(skipped_heads)} head(s) were skipped; their OvR probability "
          f"columns remain all-zero: {skipped_heads}")

# ============================= MULTICLASS CALIBRATION =============================
# Fits one TemperatureScaling model on the stacked OvR probabilities of the CAL split
# and applies it to the full TEST probability matrix.

print("\nFitting multiclass TemperatureScaling on CAL split (excluded classes masked out)...")

# Guards: ensure probs are in [0,1]
if (P_cal < 0).any() or (P_cal > 1).any():
    raise ValueError("P_cal must be probabilities in [0,1].")
if (P_te < 0).any() or (P_te > 1).any():
    raise ValueError("P_te must be probabilities in [0,1].")

ts_cal = TemperatureScaling()
# Fit only on CAL rows whose true label is one of the included classes
ts_cal.fit(P_cal[mask_cal_mc.values, :], y_cal_multiclass)
P_te_mc = ts_cal.transform(P_te)

# Ensure calibrated probs shape (K)
P_te_mc = np.asarray(P_te_mc)
# Remember what the scaler actually returned: the fallback below overwrites P_te_mc,
# and the warning must report the unexpected shape, not the shape of the replacement.
ts_returned_shape = P_te_mc.shape
if P_te_mc.ndim == 1:
    P_te_mc = P_te_mc.reshape(-1, 1)
if P_te_mc.shape[1] == 1 and K == 2:
    # Single positive-class column for a binary problem -> expand to [P(0), P(1)].
    P_te_mc = np.hstack([1.0 - P_te_mc, P_te_mc])
elif P_te_mc.shape[1] != K:
    # Fallback: normalize OvR sums
    row_sums = P_te.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0.0] = 1.0   # avoid 0/0 for rows where every head returned 0
    P_te_mc = P_te / row_sums
    print(f"[WARN] TemperatureScaling returned shape {ts_returned_shape}; fell back to sum-normalized OvR probs.")

# Persist multiclass temp scaler + INCLUDED class names
joblib.dump(ts_cal, models_dir / f"{name_target_class}_multiclass_temp_scaler.joblib")
(pd.Series(class_names, name="class_name")
   .to_csv(models_dir / f"{name_target_class}_class_names.csv", index=False))

# ============================= PROBS COMPARISON & METRICS =============================

# Evaluate & save on TEST rows whose true label is an INCLUDED class
test_index_masked = Hao_data_Test_Sub.index[mask_test_mc.values]

# Raw OvR probabilities and temperature-scaled probabilities, one column per class.
probs_raw_df = pd.DataFrame(P_te[mask_test_mc.values, :],    index=test_index_masked,
                            columns=[f"raw_{c}" for c in class_names])
probs_mc_df  = pd.DataFrame(P_te_mc[mask_test_mc.values, :], index=test_index_masked,
                            columns=[f"mc_{c}"  for c in class_names])

probs_compare = pd.concat([probs_raw_df, probs_mc_df], axis=1)
probs_compare["true_label"]    = Hao_data_Test.loc[mask_test_mc, "Celltype"].values
probs_compare["pred_raw"]      = P_te[mask_test_mc.values, :].argmax(axis=1)
probs_compare["pred_mc"]       = P_te_mc[mask_test_mc.values, :].argmax(axis=1)
probs_compare["pred_raw_name"] = [class_names[i] for i in probs_compare["pred_raw"].values]
probs_compare["pred_mc_name"]  = [class_names[i] for i in probs_compare["pred_mc"].values]

print("\nPreview of probabilities BEFORE (raw OvR) vs AFTER (multiclass TS) [included classes only]:")
print(probs_compare.head(10).to_string())

probs_compare_path = models_dir / f"{name_target_class}_probabilities_before_after_TEST_included.csv"
probs_compare.to_csv(probs_compare_path, index=True)
print(f"\nSaved probabilities comparison to: {probs_compare_path}")

# Multiclass evaluation on the masked subset.
# `labels` is passed explicitly to BOTH metrics so the report always has one row per
# included class. Without it, classification_report infers the label set from the data
# and raises ValueError whenever a class in `class_names` has no true or predicted cell
# in TEST (inferred label count != len(target_names)).
mc_labels = list(range(K))
y_pred_mc = P_te_mc[mask_test_mc.values, :].argmax(axis=1)
print("\nMulticlass classification report (TEST, excluded classes removed):")
print(classification_report(y_test_multiclass, y_pred_mc, labels=mc_labels,
                            target_names=class_names, digits=3))

cm = confusion_matrix(y_test_multiclass, y_pred_mc, labels=mc_labels)
print("Confusion matrix (rows=true, cols=pred):\n", cm)

# Per-class binary head metrics CSV
metrics_df = pd.DataFrame.from_records(metrics_log)
MLTraining.append_metrics_csv(metrics_df, csv_path=Path(models_output) / "stacker_metrics.csv")

print("\nDone.")


## Zhang Models

In [None]:
# Create the output folders for the Zhang atlas under `data_path`
# (`data_path` must already be defined by an earlier cell).
os.makedirs(data_path + "/Zhang", exist_ok=True)
os.makedirs(data_path + "/Zhang/Models", exist_ok=True)

# Rebinds `models_output`: every training cell below this point writes its figures,
# models and metrics under the Zhang folder.
models_output = data_path + "/Zhang/"

### ML Training

In [None]:
Zhang_Models = {}

#### Broad annotation

In [None]:
# -*- coding: utf-8 -*-
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.ensemble import StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.calibration import CalibratedClassifierCV
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix

from netcal.scaling import TemperatureScaling
import joblib

# ------------------------------------------------------------------- CONFIG (expects these to already exist)
#   models_output, train_barcodes_path, test_barcodes_path
#   Zhang_data_Train, Zhang_data_Test, Zhang_data_Cal          (DataFrames indexed by barcode)
#   Zhang_dataset_Train, Zhang_dataset_Test, Zhang_dataset_Cal (AnnData with obs labels)
#   TotalSeqD_Heme_Oncology_CAT399906                    (iterable of feature names)
#   MLTraining module with: CV, train_NB, train_XGB, train_KNN, train_MLP,
#                           plot_calibration_curve, save_models, evaluate_classifier, append_metrics_csv

# Annotation depth handled by this cell. It drives the AnnData label column
# (f"Consensus_annotation_{name_target_class.lower()}_final") and the output file names.
name_target_class = "Broad"   # "Broad" | "Simplified" | "Detailed"
# Output locations, both created up-front under `models_output`.
fig_root   = Path(models_output) / "Figures"
models_dir = Path(models_output) / "Models"
fig_root.mkdir(parents=True, exist_ok=True)
models_dir.mkdir(parents=True, exist_ok=True)

# Cross-validation splitter exposed by MLTraining (rebinds the notebook-level `kf`).
kf         = MLTraining.CV
# NOTE(review): -1 presumably means "use all cores" (joblib convention) — confirm in MLTraining.
num_cores  = -1
# Fresh metrics accumulator for this cell.
# NOTE(review): presumably filled with one dict per binary head, as in the Hao cells — the consuming code is below this view.
metrics_log = []

# ============================= HELPERS =============================

def _norm_feats(names) -> pd.Index:
    """
    Normalizer used ONLY to construct matching keys.
    Panel names remain untouched; data columns are normalized and then mapped BACK
    to the exact panel names via a lookup.
    """
    s = pd.Index(map(str, names))
    s = (s.str.strip()
           .str.lower()
           .str.replace(r"[ _/]+", "-", regex=True)
           .str.replace(r"-+", "-", regex=True)
           .str.strip("-"))
    return s

def attach_celltype(df: pd.DataFrame, ad: "AnnData", field: str) -> pd.DataFrame:
    """
    Return a copy of `df` with a categorical 'Celltype' column taken from `ad.obs[field]`.

    Labels are cast to string, trimmed, and every whitespace run is replaced by "_".
    They are aligned to `df` by index; rows of `df` whose index is absent from
    `ad.obs` get a missing Celltype and a warning with their count is printed.

    Raises:
        KeyError: if `field` is not a column of `ad.obs`.
    """
    obs = ad.obs
    if field not in obs:
        raise KeyError(f"'{field}' not found in AnnData.obs")

    labels = obs[field].astype("string")
    labels = labels.str.strip()
    labels = labels.str.replace(r"\s+", "_", regex=True)

    annotated = df.copy()
    annotated["Celltype"] = pd.Categorical(labels.reindex(annotated.index))

    n_unlabelled = int(annotated["Celltype"].isna().sum())
    if n_unlabelled:
        print(f"[WARN] {n_unlabelled} rows got NaN Celltype after reindex; check barcode alignment.")
    return annotated

def _check_finite(df: pd.DataFrame, tag: str):
    arr = df.to_numpy()
    if not np.isfinite(arr).all():
        bad = np.where(~np.isfinite(arr))
        raise ValueError(f"Non-finite values found in {tag} features at positions {bad}")

def _unwrap_estimator(m):
    return getattr(m, "estimator", None) or getattr(m, "base_estimator", None) or m

def _assert_feature_counts(cell_name: str, models_dict: dict, expected: int):
    """
    Debug guard: verify every fitted model saw exactly `expected` input features.

    Args:
        cell_name: class/head name, used only in the error message.
        models_dict: may contain the keys "NB", "XGB", "KNN", "MLP", "Stacker";
            missing or None entries are ignored.
        expected: number of feature columns the models should have been fit on.

    Raises:
        RuntimeError: if any inspected object reports a different `n_features_in_`.

    Both the object itself and the estimator it wraps are inspected. Looking only at
    the wrapped estimator silently skips the check for wrappers that keep an unfitted
    template in `.estimator` (e.g. hyper-parameter search objects), because that
    template has no `n_features_in_`.
    """
    for label in ("NB", "XGB", "KNN", "MLP", "Stacker"):
        est = models_dict.get(label)
        if est is None:
            continue
        for candidate in (est, _unwrap_estimator(est)):
            # getattr default also absorbs the AttributeError raised by unfitted objects.
            nfi = getattr(candidate, "n_features_in_", None)
            if nfi is not None and nfi != expected:
                raise RuntimeError(f"{cell_name}:{label} saw {nfi} features; expected {expected}")

# ============================= LABEL ATTACH =============================

# AnnData.obs column holding the consensus labels for the selected annotation depth.
consensus_field = f"Consensus_annotation_{name_target_class.lower()}_final"

# Each split is REBOUND to a copy that carries a categorical 'Celltype' column
# (attach_celltype aligns the labels to the frame's index).
Zhang_data_Train = attach_celltype(Zhang_data_Train, Zhang_dataset_Train, consensus_field)
Zhang_data_Test  = attach_celltype(Zhang_data_Test,  Zhang_dataset_Test,  consensus_field)
Zhang_data_Cal   = attach_celltype(Zhang_data_Cal,   Zhang_dataset_Cal,   consensus_field)

# ============================= PANEL & DATA COLUMN ALIGNMENT =============================

# Keep the panel EXACTLY as provided (only cast to str); its spelling is the target
# spelling for the data columns.
panel = pd.Index(map(str, TotalSeqD_Heme_Oncology_CAT399906))

# Build a mapping: normalized_key -> exact panel name
panel_keys    = _norm_feats(panel)
norm_to_panel = dict(zip(panel_keys, panel))
# A dict keeps one entry per key, so a shorter dict means two panel names share a
# normalized key (or the panel contains a repeated name) and the lookup would be ambiguous.
if len(norm_to_panel) != len(panel):
    raise ValueError("Panel contains names that collide after normalization. Consider adjusting _norm_feats rules.")

def _rename_data_to_panel(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a new frame whose feature columns carry the exact panel spelling.

    Every column except 'cell_barcode' and 'Celltype' is normalized with
    `_norm_feats` and looked up in `norm_to_panel`. Columns without a panel match
    keep their name. When several data columns resolve to the same panel name, the
    first one (in column order) is renamed and the others are dropped with a warning.
    """
    reserved = {"cell_barcode", "Celltype"}
    feature_cols = pd.Index([col for col in df.columns if col not in reserved])

    # data column -> panel name, only for columns that have a match in the panel
    targets = {}
    for col, key in zip(feature_cols, _norm_feats(feature_cols)):
        panel_name = norm_to_panel.get(key)
        if panel_name is not None:
            targets[col] = panel_name

    # First column to claim a panel name wins; later claimants are dropped.
    claimed, keep_map, duplicate_cols = set(), {}, []
    for col, panel_name in targets.items():
        if panel_name in claimed:
            duplicate_cols.append(col)
            continue
        claimed.add(panel_name)
        keep_map[col] = panel_name

    renamed = df
    if duplicate_cols:
        print(f"[WARN] Dropping {len(duplicate_cols)} duplicated-mapped columns (showing up to 5): {duplicate_cols[:5]}")
        renamed = renamed.drop(columns=duplicate_cols, errors="ignore")
    # rename() returns a new frame, so the caller's `df` is never modified.
    renamed = renamed.rename(columns=keep_map)

    print(f"[map] matched {len(keep_map)}/{len(feature_cols)} data columns to panel")
    return renamed

# Apply: normalize/rename ONLY data splits (panel remains untouched)
Zhang_data_Train = _rename_data_to_panel(Zhang_data_Train)
Zhang_data_Test  = _rename_data_to_panel(Zhang_data_Test)
Zhang_data_Cal   = _rename_data_to_panel(Zhang_data_Cal)

# Intersect each split with the panel IN PANEL ORDER
def _panel_intersection(df: pd.DataFrame) -> pd.DataFrame:
    """
    Keep only the columns of ``df`` that are panel features, followed by the
    'cell_barcode' / 'Celltype' columns when they are present.

    The intersection is taken from ``panel`` with ``sort=False``.
    Raises ValueError when no feature column of ``df`` is found in ``panel``.
    """
    meta_cols = [c for c in ["cell_barcode", "Celltype"] if c in df.columns]
    data_feats = pd.Index([c for c in df.columns if c not in meta_cols])
    shared = panel.intersection(data_feats, sort=False)
    if len(shared) == 0:
        raise ValueError("Panel/Data intersection is empty after renaming. Check mapping rules.")
    return df.reindex(columns=[*shared, *meta_cols])

# Restrict each split to the features shared with the panel. The result is bound
# to the same names, so everything below only sees the restricted frames.
Zhang_data_Train = _panel_intersection(Zhang_data_Train)
Zhang_data_Test  = _panel_intersection(Zhang_data_Test)
Zhang_data_Cal   = _panel_intersection(Zhang_data_Cal)

# ============================= FEATURES & LABELS =============================

# Calibration labels are kept in their own one-column frame (same index as the CAL split).
Zhang_data_Cal_lbl = Zhang_data_Cal[["Celltype"]].copy()

# Non-feature columns that actually exist in each split.
drop_cols_train = [c for c in ["cell_barcode", "Celltype"] if c in Zhang_data_Train.columns]
drop_cols_test  = [c for c in ["cell_barcode", "Celltype"] if c in Zhang_data_Test.columns]
drop_cols_cal   = [c for c in ["cell_barcode", "Celltype"] if c in Zhang_data_Cal.columns]

# The *_Sub frames hold features only; the full frames keep 'Celltype' for label lookups.
Zhang_data_Train_Sub = Zhang_data_Train.drop(columns=drop_cols_train, errors="ignore")
Zhang_data_Test_Sub  = Zhang_data_Test.drop(columns=drop_cols_test,  errors="ignore")
Zhang_data_Cal_Sub   = Zhang_data_Cal.drop(columns=drop_cols_cal,    errors="ignore")

# SAFETY: shared columns & finiteness checks
# The splits must expose identical feature columns in identical order: the
# per-class scalers below are fitted and applied on raw `.values` arrays, so
# nothing realigns columns by name afterwards.
cols_train = list(Zhang_data_Train_Sub.columns)
if list(Zhang_data_Test_Sub.columns) != cols_train or list(Zhang_data_Cal_Sub.columns) != cols_train:
    raise ValueError("Train/Cal/Test feature columns differ after panel intersection!")

# Abort early on NaN / inf in any feature matrix.
_check_finite(Zhang_data_Train_Sub, "TRAIN")
_check_finite(Zhang_data_Test_Sub,  "TEST")
_check_finite(Zhang_data_Cal_Sub,   "CAL")

print(f"\n[features] Using {len(cols_train)} panel-intersected features (exact panel names):")
print(cols_train)

# Consistent class order
# Classes are the sorted unique non-null TRAIN labels. A class's position in
# class_names is both its integer id and its column in P_cal / P_te.
class_names  = sorted(pd.Series(Zhang_data_Train["Celltype"]).dropna().unique())
K            = len(class_names)
class_to_idx = {c: i for i, c in enumerate(class_names)}

# Multiclass labels arrays
# Labels that do not map to a TRAIN class (including missing labels) become NaN
# and abort the run instead of being silently miscoded.
s_cal = Zhang_data_Cal_lbl["Celltype"].map(class_to_idx)
if s_cal.isna().any():
    missing = Zhang_data_Cal_lbl.loc[s_cal.isna(), "Celltype"].unique()
    raise ValueError(f"Unknown labels in CAL split: {missing}")
y_cal_multiclass = s_cal.to_numpy(dtype=np.int64)

s_te = Zhang_data_Test["Celltype"].map(class_to_idx)
if s_te.isna().any():
    missing = Zhang_data_Test.loc[s_te.isna(), "Celltype"].unique()
    raise ValueError(f"Unknown labels in TEST split: {missing}")
y_test_multiclass = s_te.to_numpy(dtype=np.int64)

# Reuse across classes
# Unscaled full CAL / TEST feature frames; each head rescales them with its own scaler.
X_cal_all_df = Zhang_data_Cal_Sub.copy()
X_te_all_df  = Zhang_data_Test_Sub.copy()

# Preallocate OvR prob mats
# One row per cell, one column per class; column k is filled by the head of class k.
P_cal = np.zeros((X_cal_all_df.shape[0], K), dtype=float)
P_te  = np.zeros((X_te_all_df.shape[0],  K), dtype=float)

# Row index used later when the TEST probabilities are written out as a table.
test_index = Zhang_data_Test_Sub.index

# ============================= TRAIN PER-CLASS OVR =============================

# Train one binary (one-vs-rest) stacked classifier per class, calibrate it on
# the CAL split, save it, and store its positive-class probabilities for the
# multiclass calibration that follows the loop.
for celltype in class_names:
    k = class_to_idx[celltype]              # column of this class in P_cal / P_te
    name = str(celltype).replace(" ", "_")  # label as used in barcode file names and output folders
    print(f"\nProcessing {name} (class {k+1}/{K})...")

    # ---- TRAIN slice via barcode lists
    # The CSV supplies a "Positive" and a "Negative" barcode column for this class.
    train_barcodes_df = pd.read_csv(
        f"{train_barcodes_path}/Zhang/Consensus_annotation_broad_final/Barcodes_training_class_{name}.csv",
        index_col=0
    )
    train_positive_barcodes = train_barcodes_df["Positive"].dropna().values
    train_negative_barcodes = train_barcodes_df["Negative"].dropna().values
    all_train_barcodes = np.concatenate([train_positive_barcodes, train_negative_barcodes])

    # Keep the listed barcodes that exist in the TRAIN features; y=1 when listed as positive.
    train_mask = Zhang_data_Train_Sub.index.isin(all_train_barcodes)
    X_tr_df = Zhang_data_Train_Sub.loc[train_mask]
    found_train_barcodes = X_tr_df.index.values
    y_tr = np.isin(found_train_barcodes, train_positive_barcodes).astype(int)

    # ---- Skip guards
    # NOTE(review): a skipped class writes no bundle and leaves column k of
    # P_cal / P_te at the preallocated zeros, which still enter the multiclass step.
    if X_tr_df.empty or np.unique(y_tr).size < 2:
        print(f"[SKIP] {name}: empty or single-class train slice (pos={y_tr.sum()}, neg={(len(y_tr)-y_tr.sum())}).")
        continue

    # ---- TEST slice via barcode lists
    # Same construction as the TRAIN slice, from the testing barcode file.
    test_barcodes_df = pd.read_csv(
        f"{test_barcodes_path}/Zhang/Consensus_annotation_broad_final/Barcodes_testing_class_{name}.csv",
        index_col=0
    )
    test_positive_barcodes = test_barcodes_df["Positive"].dropna().values
    test_negative_barcodes = test_barcodes_df["Negative"].dropna().values
    all_test_barcodes = np.concatenate([test_positive_barcodes, test_negative_barcodes])

    test_mask = Zhang_data_Test_Sub.index.isin(all_test_barcodes)
    X_te_df = Zhang_data_Test_Sub.loc[test_mask]
    found_test_barcodes = X_te_df.index.values
    y_te = np.isin(found_test_barcodes, test_positive_barcodes).astype(int)

    # ---- Full-test & cal for this binary head
    # Binary targets over ALL test / cal rows: 1 where the cell's label is this class.
    X_te_all_local = X_te_all_df.copy()
    y_te_all = (Zhang_data_Test["Celltype"].values == celltype).astype(int)
    X_cal_df = X_cal_all_df.copy()
    y_cal_bin = (Zhang_data_Cal_lbl["Celltype"].values == celltype).astype(int)

    # ---- Info
    print(f"Training - Found {X_tr_df.shape[0]} / {len(all_train_barcodes)} barcodes")
    print(f"Training - Pos: {len(train_positive_barcodes)}, Neg: {len(train_negative_barcodes)}")
    print(f"Training labels: {y_tr.sum()} pos, {len(y_tr)-y_tr.sum()} neg")
    print(f"Testing  - Found {X_te_df.shape[0]} / {len(all_test_barcodes)} barcodes")
    print(f"Testing  - Pos: {len(test_positive_barcodes)}, Neg: {len(test_negative_barcodes)}")
    print(f"Testing labels: {y_te.sum()} pos, {len(y_te)-y_te.sum()} neg")
    print(f"Calibrating - Found {X_cal_df.shape[0]} rows | Pos: {y_cal_bin.sum()}, Neg: {len(y_cal_bin)-y_cal_bin.sum()}")
    print(f"All test data: {X_te_all_local.shape[0]} rows, positives for {celltype}: {y_te_all.sum()}")

    # ---- Scaling (fit on per-head TRAIN slice; transform others)
    scaler = StandardScaler(with_mean=True, with_std=True).fit(X_tr_df.values)

    def _sc(df):
        # Apply this head's scaler and return a frame with the original index
        # and the shared feature column names.
        return pd.DataFrame(
            scaler.transform(df.values),
            index=df.index,
            columns=cols_train,
        )

    X_tr_sc_df     = _sc(X_tr_df)
    X_te_sc_df     = _sc(X_te_df)
    X_te_all_sc_df = _sc(X_te_all_local)
    X_cal_sc_df    = _sc(X_cal_df)

    print(f"[scale] {name}: train mean ~ {X_tr_sc_df.values.mean():.3f}, std ~ {X_tr_sc_df.values.std():.3f}")

    # ---- Base learners
    # Each train_* helper comes from the project's MLTraining module and receives
    # the scaled TRAIN slice, the CV object and the class name.
    NB_model  = MLTraining.train_NB (X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    XGB_model = MLTraining.train_XGB(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    KNN_model = MLTraining.train_KNN(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    MLP_model = MLTraining.train_MLP(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)

    # ---- Stacker (raw)
    # Logistic-regression meta-learner on the base learners' predict_proba outputs.
    stacker_raw = StackingClassifier(
        estimators=[("NB", NB_model), ("XGB", XGB_model), ("KNN", KNN_model), ("MLP", MLP_model)],
        final_estimator=LogisticRegression(max_iter=2000, class_weight="balanced", random_state=42),
        stack_method="predict_proba",
        cv=kf,
        n_jobs=-1,
    ).fit(X_tr_sc_df, y_tr)

    # ---- Feature count asserts (debug safety)
    expected_feats = len(cols_train)
    _assert_feature_counts(name, {
        "NB": NB_model, "XGB": XGB_model, "KNN": KNN_model, "MLP": MLP_model, "Stacker": stacker_raw
    }, expected_feats)

    # ---- Binary calibration (Platt, guarded)
    # Calibration is only attempted when the CAL split has both positives and negatives.
    pos_cal    = int(y_cal_bin.sum())
    n_cal_bin  = int(len(y_cal_bin))
    has_both   = (0 < pos_cal < n_cal_bin)
    print(f"[CAL] {name}: cal positives={pos_cal}/{n_cal_bin}")

    if has_both:
        # cv="prefit": the already-fitted stacker is kept as is; only the sigmoid is fitted on CAL.
        try:
            calibrator = CalibratedClassifierCV(estimator=stacker_raw, method="sigmoid", cv="prefit")
        except TypeError:  # older sklearn
            calibrator = CalibratedClassifierCV(base_estimator=stacker_raw, method="sigmoid", cv="prefit")
        stacker = calibrator.fit(X_cal_sc_df, y_cal_bin)
    else:
        print(f"[WARN] Skipping calibration for {name}: single-class cal set.")
        stacker = stacker_raw

    # ---- Calibration plot on all-test (optional)
    # Best effort: any failure here is reported and does not stop the class from being saved.
    try:
        y_proba_uncal = stacker_raw.predict_proba(X_te_all_sc_df)[:, 1]
        y_proba_cal   = stacker.predict_proba(X_te_all_sc_df)[:, 1]
        if has_both:
            _ = MLTraining.plot_calibration_curve(
                y_te_all, [y_proba_uncal, y_proba_cal],
                clf_names=["Uncalibrated", "Calibrated"],
                n_bins=15, strategy="quantile",
                title=f"Calibration – {name_target_class}:{name}"
            )
        else:
            _ = MLTraining.plot_calibration_curve(
                y_te_all, [y_proba_uncal],
                clf_names=["Uncalibrated"],
                n_bins=15, strategy="quantile",
                title=f"Calibration (uncal only) – {name_target_class}:{name}"
            )
        plt.tight_layout()
        plt.show()
    except Exception as e:
        print(f"[WARN] Skipped calibration plot for {name}: {e}")

    # ---- Save per-class bundle (model + scaler + columns)
    save_subdir = models_dir / f"{name_target_class}_{name}"
    save_subdir.mkdir(parents=True, exist_ok=True)

    MLTraining.save_models({"Stacked": stacker}, out_dir=save_subdir, tag=f"{name_target_class}_{name}")
    joblib.dump(cols_train, save_subdir / "feature_names.joblib")

    # The bundle carries what is needed to score new data with this head:
    # the model, the scaler fitted for it, and the feature columns in order.
    bundle = {
        "atlas": "Zhang",
        "depth": name_target_class,
        "label": celltype,
        "model": stacker,          # CalibratedClassifierCV(StackingClassifier) or StackingClassifier
        "columns": cols_train,     # exact panel names, panel order
        "scaler": scaler,          # per-head scaler
        "panel_name": "TotalSeqD_Heme_Oncology_CAT399906",
    }
    bundle_path = save_subdir / f"{name_target_class}_{name}_bundle.joblib"
    joblib.dump(bundle, bundle_path)
    print(f"[SAVE] Wrote bundle with columns+scaler to {bundle_path}")

    # ---- Binary metrics on the class-specific test slice
    # Best effort as well: a failure is logged and the class simply has no metrics row.
    try:
        m = MLTraining.evaluate_classifier(stacker, X_te_sc_df, y_te, plot_cm=False)
        m.update(celltype=celltype)
        metrics_log.append(m)
        print(f"\n{celltype}\n", m.get("report", ""))
    except Exception as e:
        print(f"[WARN] Binary metrics for {name} skipped: {e}")

    # ---- Store OvR probs for multiclass calibration
    # Positive-class probability of this head over ALL cal / test rows.
    P_cal[:, k] = stacker.predict_proba(X_cal_sc_df)[:, 1]
    P_te[:,  k] = stacker.predict_proba(X_te_all_sc_df)[:, 1]

# ============================= MULTICLASS CALIBRATION =============================

print("\nFitting multiclass TemperatureScaling on CAL split...")

# Guards: ensure probs are in [0,1]
if (P_cal < 0).any() or (P_cal > 1).any():
    raise ValueError("P_cal must be probabilities in [0,1].")
if (P_te < 0).any() or (P_te > 1).any():
    raise ValueError("P_te must be probabilities in [0,1].")

# A class skipped in the training loop keeps its preallocated all-zero column;
# surface that instead of letting it enter the calibration unnoticed.
if not P_cal.any(axis=0).all():
    print(f"[WARN] All-zero OvR columns (head skipped during training?): "
          f"{[c for c, filled in zip(class_names, P_cal.any(axis=0)) if not filled]}")

# Fit the temperature on the CAL probabilities, then rescale the TEST probabilities.
ts_cal = TemperatureScaling()
ts_cal.fit(P_cal, y_cal_multiclass)
P_te_mc = ts_cal.transform(P_te)

# Ensure calibrated probs have shape (n_test, K)
P_te_mc = np.asarray(P_te_mc)
if P_te_mc.ndim == 1:
    P_te_mc = P_te_mc.reshape(-1, 1)
if P_te_mc.shape[1] == 1 and K == 2:
    # A single column for a two-class problem: expand to [P(class 0), P(class 1)].
    P_te_mc = np.hstack([1.0 - P_te_mc, P_te_mc])
elif P_te_mc.shape[1] != K:
    # Report the shape that was actually returned BEFORE it is replaced by the fallback.
    print(f"[WARN] TemperatureScaling returned shape {P_te_mc.shape}; fell back to sum-normalized OvR probs.")
    # Fallback: normalize OvR sums (rows summing to zero are left as zeros)
    row_sums = P_te.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0.0] = 1.0
    P_te_mc = P_te / row_sums

# Persist multiclass temp scaler + class names
# The row order of the class-name CSV is the column order of the probability matrices.
joblib.dump(ts_cal, models_dir / f"{name_target_class}_multiclass_temp_scaler.joblib")
(pd.Series(class_names, name="class_name")
   .to_csv(models_dir / f"{name_target_class}_class_names.csv", index=False))

# ============================= PROBS COMPARISON & METRICS =============================

# One row per TEST cell: raw OvR probabilities ("raw_*") next to the
# temperature-scaled ones ("mc_*").
probs_raw_df = pd.DataFrame(P_te,    index=test_index, columns=[f"raw_{c}" for c in class_names])
probs_mc_df  = pd.DataFrame(P_te_mc, index=test_index, columns=[f"mc_{c}"  for c in class_names])

# Add the true label and the argmax prediction (as class index and as class name)
# before and after the multiclass calibration.
probs_compare = pd.concat([probs_raw_df, probs_mc_df], axis=1)
probs_compare["true_label"]    = Zhang_data_Test["Celltype"].values
probs_compare["pred_raw"]      = P_te.argmax(axis=1)
probs_compare["pred_mc"]       = P_te_mc.argmax(axis=1)
probs_compare["pred_raw_name"] = [class_names[i] for i in probs_compare["pred_raw"].values]
probs_compare["pred_mc_name"]  = [class_names[i] for i in probs_compare["pred_mc"].values]

print("\nPreview of probabilities BEFORE (raw OvR) vs AFTER (multiclass TS):")
print(probs_compare.head(10).to_string())

# The full table is written with the barcode index as first column.
probs_compare_path = models_dir / f"{name_target_class}_probabilities_before_after_TEST.csv"
probs_compare.to_csv(probs_compare_path, index=True)
print(f"\nSaved probabilities comparison to: {probs_compare_path}")

# Multiclass evaluation
# Predicted class = column with the highest calibrated probability.
y_pred_mc = P_te_mc.argmax(axis=1)
print("\nMulticlass classification report (TEST):")
# `labels` is passed explicitly so that the report always has one row per entry
# of class_names. Without it the labels are inferred from the data, and the call
# fails when a class occurs in neither the true nor the predicted TEST labels.
print(classification_report(
    y_test_multiclass, y_pred_mc,
    labels=list(range(K)), target_names=class_names, digits=3,
))

cm = confusion_matrix(y_test_multiclass, y_pred_mc, labels=range(K))
print("Confusion matrix (rows=true, cols=pred):\n", cm)

# Per-class binary head metrics CSV
# metrics_log holds one record per class whose binary evaluation succeeded in the loop.
metrics_df = pd.DataFrame.from_records(metrics_log)
MLTraining.append_metrics_csv(metrics_df, csv_path=Path(models_output) / "stacker_metrics.csv")

print("\nDone.")


#### Simplified annotation

In [None]:
# -*- coding: utf-8 -*-
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.ensemble import StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.calibration import CalibratedClassifierCV
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix

from netcal.scaling import TemperatureScaling
import joblib

# ------------------------------------------------------------------- CONFIG (expects these to already exist)
#   models_output, train_barcodes_path, test_barcodes_path
#   Zhang_data_Train, Zhang_data_Test, Zhang_data_Cal          (DataFrames indexed by barcode)
#   Zhang_dataset_Train, Zhang_dataset_Test, Zhang_dataset_Cal (AnnData with obs labels)
#   TotalSeqD_Heme_Oncology_CAT399906                    (iterable of feature names)
#   MLTraining module with: CV, train_NB, train_XGB, train_KNN, train_MLP,
#                           plot_calibration_curve, save_models, evaluate_classifier, append_metrics_csv

# Annotation depth handled by this cell; it is used in output file/folder names
# and (lower-cased) in the name of the AnnData.obs label column.
name_target_class = "Simplified"   # one of "Broad" | "Simplified" | "Detailed"
# Output folders below models_output (created if they do not exist yet).
fig_root   = Path(models_output) / "Figures"
models_dir = Path(models_output) / "Models"
fig_root.mkdir(parents=True, exist_ok=True)
models_dir.mkdir(parents=True, exist_ok=True)

# NOTE: these rebind the notebook-level `kf` and `num_cores` from the set-up cells.
kf         = MLTraining.CV   # cross-validation object provided by the project's MLTraining module
num_cores  = -1
metrics_log = []             # one metrics record per class, filled inside the training loop

# ============================= HELPERS =============================

def _norm_feats(names) -> pd.Index:
    """
    Normalizer used ONLY to construct matching keys.
    Panel names remain untouched; data columns are normalized and then mapped BACK
    to the exact panel names via a lookup.
    """
    s = pd.Index(map(str, names))
    s = (s.str.strip()
           .str.lower()
           .str.replace(r"[ _/]+", "-", regex=True)
           .str.replace(r"-+", "-", regex=True)
           .str.strip("-"))
    return s

def attach_celltype(df: pd.DataFrame, ad: "AnnData", field: str) -> pd.DataFrame:
    """
    Return a copy of ``df`` with a categorical 'Celltype' column taken from ``ad.obs[field]``.

    Labels are converted to strings, stripped, and have every whitespace run
    replaced by an underscore. They are aligned to ``df`` by index; rows of
    ``df`` without a label get NaN, and a warning with their count is printed.

    Raises KeyError when ``field`` is not a column of ``ad.obs``.
    """
    if field not in ad.obs:
        raise KeyError(f"'{field}' not found in AnnData.obs")

    labels = ad.obs[field].astype("string")
    labels = labels.str.strip()
    labels = labels.str.replace(r"\s+", "_", regex=True)

    out = df.copy()
    out["Celltype"] = pd.Categorical(labels.reindex(out.index))

    n_missing = int(out["Celltype"].isna().sum())
    if n_missing > 0:
        print(f"[WARN] {n_missing} rows got NaN Celltype after reindex; check barcode alignment.")
    return out

def _check_finite(df: pd.DataFrame, tag: str):
    arr = df.to_numpy()
    if not np.isfinite(arr).all():
        bad = np.where(~np.isfinite(arr))
        raise ValueError(f"Non-finite values found in {tag} features at positions {bad}")

def _unwrap_estimator(m):
    return getattr(m, "estimator", None) or getattr(m, "base_estimator", None) or m

def _assert_feature_counts(cell_name: str, models_dict: dict, expected: int):
    pairs = [
        ("NB",  models_dict.get("NB")),
        ("XGB", models_dict.get("XGB")),
        ("KNN", models_dict.get("KNN")),
        ("MLP", models_dict.get("MLP")),
        ("Stacker", models_dict.get("Stacker")),
    ]
    for name, est in pairs:
        if est is None:
            continue
        base = _unwrap_estimator(est)
        nfi = getattr(base, "n_features_in_", None)
        if nfi is not None and nfi != expected:
            raise RuntimeError(f"{cell_name}:{name} saw {nfi} features; expected {expected}")

# ============================= LABEL ATTACH =============================

# Name of the AnnData.obs column that holds the labels for this annotation depth.
consensus_field = f"Consensus_annotation_{name_target_class.lower()}_final"

# Add the 'Celltype' column to each split from its matching AnnData object.
Zhang_data_Train, Zhang_data_Test, Zhang_data_Cal = (
    attach_celltype(split_df, split_adata, consensus_field)
    for split_df, split_adata in (
        (Zhang_data_Train, Zhang_dataset_Train),
        (Zhang_data_Test,  Zhang_dataset_Test),
        (Zhang_data_Cal,   Zhang_dataset_Cal),
    )
)

# ============================= PANEL & DATA COLUMN ALIGNMENT =============================

# The antibody panel is the reference: names are stringified but never rewritten.
panel = pd.Index([str(feature) for feature in TotalSeqD_Heme_Oncology_CAT399906])

# Lookup table: normalized key -> exact panel spelling.
panel_keys = _norm_feats(panel)
norm_to_panel = {key: exact_name for key, exact_name in zip(panel_keys, panel)}

# Fewer lookup entries than panel names means at least two names share one key.
if len(panel) != len(norm_to_panel):
    raise ValueError("Panel contains names that collide after normalization. Consider adjusting _norm_feats rules.")

def _rename_data_to_panel(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of ``df`` whose feature columns carry the exact panel names.

    Every column other than 'cell_barcode' / 'Celltype' is normalized with
    ``_norm_feats`` and looked up in ``norm_to_panel``. Columns with a match are
    renamed to the panel spelling; columns without a match keep their name.
    When several columns resolve to the same panel name, the first one is
    renamed and the later ones are dropped (a warning is printed).
    """
    out = df.copy()
    meta_cols = [c for c in ["cell_barcode", "Celltype"] if c in out.columns]
    feat = pd.Index([c for c in out.columns if c not in meta_cols])

    # Original column name -> panel name, only for columns whose key is in the panel.
    candidates = {}
    for old_name, key in zip(feat, _norm_feats(feat)):
        panel_name = norm_to_panel.get(key)
        if panel_name is not None:
            candidates[old_name] = panel_name

    # The first column claiming a panel name wins; later claimants are dropped.
    claimed = set()
    safe_map = {}
    drops = []
    for old_name, panel_name in candidates.items():
        if panel_name not in claimed:
            claimed.add(panel_name)
            safe_map[old_name] = panel_name
        else:
            drops.append(old_name)

    if drops:
        print(f"[WARN] Dropping {len(drops)} duplicated-mapped columns (showing up to 5): {drops[:5]}")
        out = out.drop(columns=drops, errors="ignore")
    out = out.rename(columns=safe_map)

    print(f"[map] matched {len(safe_map)}/{len(feat)} data columns to panel")
    return out

# Rename the feature columns of the three data splits; the panel itself is left as is.
Zhang_data_Train, Zhang_data_Test, Zhang_data_Cal = (
    _rename_data_to_panel(split_df)
    for split_df in (Zhang_data_Train, Zhang_data_Test, Zhang_data_Cal)
)

# Intersect each split with the panel IN PANEL ORDER
def _panel_intersection(df: pd.DataFrame) -> pd.DataFrame:
    """
    Keep only the columns of ``df`` that are panel features, followed by the
    'cell_barcode' / 'Celltype' columns when they are present.

    The intersection is taken from ``panel`` with ``sort=False``.
    Raises ValueError when no feature column of ``df`` is found in ``panel``.
    """
    meta_cols = [c for c in ["cell_barcode", "Celltype"] if c in df.columns]
    data_feats = pd.Index([c for c in df.columns if c not in meta_cols])
    shared = panel.intersection(data_feats, sort=False)
    if len(shared) == 0:
        raise ValueError("Panel/Data intersection is empty after renaming. Check mapping rules.")
    return df.reindex(columns=[*shared, *meta_cols])

# Restrict each split to the features shared with the panel. The result is bound
# to the same names, so everything below only sees the restricted frames.
Zhang_data_Train = _panel_intersection(Zhang_data_Train)
Zhang_data_Test  = _panel_intersection(Zhang_data_Test)
Zhang_data_Cal   = _panel_intersection(Zhang_data_Cal)

# ============================= FEATURES & LABELS =============================

# Calibration labels are kept in their own one-column frame (same index as the CAL split).
Zhang_data_Cal_lbl = Zhang_data_Cal[["Celltype"]].copy()

# Non-feature columns that actually exist in each split.
drop_cols_train = [c for c in ["cell_barcode", "Celltype"] if c in Zhang_data_Train.columns]
drop_cols_test  = [c for c in ["cell_barcode", "Celltype"] if c in Zhang_data_Test.columns]
drop_cols_cal   = [c for c in ["cell_barcode", "Celltype"] if c in Zhang_data_Cal.columns]

# The *_Sub frames hold features only; the full frames keep 'Celltype' for label lookups.
Zhang_data_Train_Sub = Zhang_data_Train.drop(columns=drop_cols_train, errors="ignore")
Zhang_data_Test_Sub  = Zhang_data_Test.drop(columns=drop_cols_test,  errors="ignore")
Zhang_data_Cal_Sub   = Zhang_data_Cal.drop(columns=drop_cols_cal,    errors="ignore")

# SAFETY: shared columns & finiteness checks
# The splits must expose identical feature columns in identical order: the
# per-class scalers below are fitted and applied on raw `.values` arrays, so
# nothing realigns columns by name afterwards.
cols_train = list(Zhang_data_Train_Sub.columns)
if list(Zhang_data_Test_Sub.columns) != cols_train or list(Zhang_data_Cal_Sub.columns) != cols_train:
    raise ValueError("Train/Cal/Test feature columns differ after panel intersection!")

# Abort early on NaN / inf in any feature matrix.
_check_finite(Zhang_data_Train_Sub, "TRAIN")
_check_finite(Zhang_data_Test_Sub,  "TEST")
_check_finite(Zhang_data_Cal_Sub,   "CAL")

print(f"\n[features] Using {len(cols_train)} panel-intersected features (exact panel names):")
print(cols_train)

# ===== Exclude specific classes from the multiclass set and per-class loop =====
# Names must match the 'Celltype' values exactly (whitespace in labels has been
# replaced by underscores when the labels were attached).
EXCLUDE_CLASSES = {"Macrophage", "ILC", "Stroma"}

all_classes = sorted(pd.Series(Zhang_data_Train["Celltype"]).dropna().unique())
class_names = [c for c in all_classes if c not in EXCLUDE_CLASSES]
if not class_names:
    raise ValueError("After exclusions, class_names is empty.")
print(f"[classes] Included ({len(class_names)}): {class_names}")
if excluded_present := [c for c in all_classes if c in EXCLUDE_CLASSES]:
    print(f"[classes] Excluded: {excluded_present}")
# An exclusion that matches no TRAIN label has no effect; most likely a spelling mismatch.
if excluded_absent := sorted(EXCLUDE_CLASSES.difference(all_classes)):
    print(f"[WARN] Exclusions not found among TRAIN labels: {excluded_absent}")

# A class's position in class_names is its integer id and its column in P_cal / P_te.
K            = len(class_names)
class_to_idx = {c: i for i, c in enumerate(class_names)}

# --- Multiclass labels (MASKED to included classes) ---
# Rows whose label is excluded or missing are left out of the multiclass
# calibration and evaluation; the kept/total counts are printed for visibility.
# CAL
mask_cal_mc = Zhang_data_Cal_lbl["Celltype"].isin(class_names)
print(f"[classes] CAL rows kept for multiclass: {int(mask_cal_mc.sum())}/{len(mask_cal_mc)}")
s_cal = Zhang_data_Cal_lbl.loc[mask_cal_mc, "Celltype"].map(class_to_idx)
if s_cal.isna().any():
    # s_cal only covers the masked rows, so select positionally within that same subset.
    missing = Zhang_data_Cal_lbl.loc[mask_cal_mc, "Celltype"][s_cal.isna().to_numpy()].unique()
    raise ValueError(f"Unknown labels in masked CAL split: {missing}")
y_cal_multiclass = s_cal.to_numpy(dtype=np.int64)

# TEST
mask_test_mc = Zhang_data_Test["Celltype"].isin(class_names)
print(f"[classes] TEST rows kept for multiclass: {int(mask_test_mc.sum())}/{len(mask_test_mc)}")
s_te = Zhang_data_Test.loc[mask_test_mc, "Celltype"].map(class_to_idx)
if s_te.isna().any():
    missing = Zhang_data_Test.loc[mask_test_mc, "Celltype"][s_te.isna().to_numpy()].unique()
    raise ValueError(f"Unknown labels in masked TEST split: {missing}")
y_test_multiclass = s_te.to_numpy(dtype=np.int64)

# Reuse across classes
# Unscaled full CAL / TEST feature frames (all rows, including excluded classes).
X_cal_all_df = Zhang_data_Cal_Sub.copy()
X_te_all_df  = Zhang_data_Test_Sub.copy()

# Preallocate OvR prob mats (only for included classes)
P_cal = np.zeros((X_cal_all_df.shape[0], K), dtype=float)
P_te  = np.zeros((X_te_all_df.shape[0],  K), dtype=float)

test_index = Zhang_data_Test_Sub.index

# ============================= TRAIN PER-CLASS OVR =============================

# Train one binary (one-vs-rest) stacked classifier per included class, calibrate
# it on the CAL split, save it, and store its positive-class probabilities for
# the multiclass calibration that follows the loop.
for celltype in class_names:
    k = class_to_idx[celltype]              # column of this class in P_cal / P_te
    name = str(celltype).replace(" ", "_")  # label as used in barcode file names and output folders
    print(f"\nProcessing {name} (class {k+1}/{K})...")

    # ---- TRAIN slice via barcode lists
    # The CSV supplies a "Positive" and a "Negative" barcode column for this class.
    train_barcodes_df = pd.read_csv(
        f"{train_barcodes_path}/Zhang/Consensus_annotation_simplified_final/Barcodes_training_class_{name}.csv",
        index_col=0
    )
    train_positive_barcodes = train_barcodes_df["Positive"].dropna().values
    train_negative_barcodes = train_barcodes_df["Negative"].dropna().values
    all_train_barcodes = np.concatenate([train_positive_barcodes, train_negative_barcodes])

    # Keep the listed barcodes that exist in the TRAIN features; y=1 when listed as positive.
    train_mask = Zhang_data_Train_Sub.index.isin(all_train_barcodes)
    X_tr_df = Zhang_data_Train_Sub.loc[train_mask]
    found_train_barcodes = X_tr_df.index.values
    y_tr = np.isin(found_train_barcodes, train_positive_barcodes).astype(int)

    # ---- Skip guards
    # NOTE(review): a skipped class writes no bundle and leaves column k of
    # P_cal / P_te at the preallocated zeros, which still enter the multiclass step.
    if X_tr_df.empty or np.unique(y_tr).size < 2:
        print(f"[SKIP] {name}: empty or single-class train slice (pos={y_tr.sum()}, neg={(len(y_tr)-y_tr.sum())}).")
        continue

    # ---- TEST slice via barcode lists
    # Same construction as the TRAIN slice, from the testing barcode file.
    test_barcodes_df = pd.read_csv(
        f"{test_barcodes_path}/Zhang/Consensus_annotation_simplified_final/Barcodes_testing_class_{name}.csv",
        index_col=0
    )
    test_positive_barcodes = test_barcodes_df["Positive"].dropna().values
    test_negative_barcodes = test_barcodes_df["Negative"].dropna().values
    all_test_barcodes = np.concatenate([test_positive_barcodes, test_negative_barcodes])

    test_mask = Zhang_data_Test_Sub.index.isin(all_test_barcodes)
    X_te_df = Zhang_data_Test_Sub.loc[test_mask]
    found_test_barcodes = X_te_df.index.values
    y_te = np.isin(found_test_barcodes, test_positive_barcodes).astype(int)

    # ---- Full-test & cal for this binary head
    # Binary targets over ALL test / cal rows: 1 where the cell's label is this
    # class. Rows of excluded classes are part of these arrays as negatives.
    X_te_all_local = X_te_all_df.copy()
    y_te_all = (Zhang_data_Test["Celltype"].values == celltype).astype(int)
    X_cal_df = X_cal_all_df.copy()
    y_cal_bin = (Zhang_data_Cal_lbl["Celltype"].values == celltype).astype(int)

    # ---- Info
    print(f"Training - Found {X_tr_df.shape[0]} / {len(all_train_barcodes)} barcodes")
    print(f"Training - Pos: {len(train_positive_barcodes)}, Neg: {len(train_negative_barcodes)}")
    print(f"Training labels: {y_tr.sum()} pos, {len(y_tr)-y_tr.sum()} neg")
    print(f"Testing  - Found {X_te_df.shape[0]} / {len(all_test_barcodes)} barcodes")
    print(f"Testing  - Pos: {len(test_positive_barcodes)}, Neg: {len(test_negative_barcodes)}")
    print(f"Testing  - labels: {y_te.sum()} pos, {len(y_te)-y_te.sum()} neg")
    print(f"Calibrating - Found {X_cal_df.shape[0]} rows | Pos: {y_cal_bin.sum()}, Neg: {len(y_cal_bin)-y_cal_bin.sum()}")
    print(f"All test data: {X_te_all_local.shape[0]} rows, positives for {celltype}: {y_te_all.sum()}")

    # ---- Scaling (fit on per-head TRAIN slice; transform others)
    scaler = StandardScaler(with_mean=True, with_std=True).fit(X_tr_df.values)

    def _sc(df):
        # Apply this head's scaler and return a frame with the original index
        # and the shared feature column names.
        return pd.DataFrame(
            scaler.transform(df.values),
            index=df.index,
            columns=cols_train,
        )

    X_tr_sc_df     = _sc(X_tr_df)
    X_te_sc_df     = _sc(X_te_df)
    X_te_all_sc_df = _sc(X_te_all_local)
    X_cal_sc_df    = _sc(X_cal_df)

    print(f"[scale] {name}: train mean ~ {X_tr_sc_df.values.mean():.3f}, std ~ {X_tr_sc_df.values.std():.3f}")

    # ---- Base learners
    # Each train_* helper comes from the project's MLTraining module and receives
    # the scaled TRAIN slice, the CV object and the class name.
    NB_model  = MLTraining.train_NB (X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    XGB_model = MLTraining.train_XGB(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    KNN_model = MLTraining.train_KNN(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    MLP_model = MLTraining.train_MLP(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)

    # ---- Stacker (raw)
    # Logistic-regression meta-learner on the base learners' predict_proba outputs.
    stacker_raw = StackingClassifier(
        estimators=[("NB", NB_model), ("XGB", XGB_model), ("KNN", KNN_model), ("MLP", MLP_model)],
        final_estimator=LogisticRegression(max_iter=2000, class_weight="balanced", random_state=42),
        stack_method="predict_proba",
        cv=kf,
        n_jobs=-1,
    ).fit(X_tr_sc_df, y_tr)

    # ---- Feature count asserts (debug safety)
    expected_feats = len(cols_train)
    _assert_feature_counts(name, {
        "NB": NB_model, "XGB": XGB_model, "KNN": KNN_model, "MLP": MLP_model, "Stacker": stacker_raw
    }, expected_feats)

    # ---- Binary calibration (Platt, guarded)
    # Calibration is only attempted when the CAL split has both positives and negatives.
    pos_cal    = int(y_cal_bin.sum())
    n_cal_bin  = int(len(y_cal_bin))
    has_both   = (0 < pos_cal < n_cal_bin)
    print(f"[CAL] {name}: cal positives={pos_cal}/{n_cal_bin}")

    if has_both:
        # cv="prefit": the already-fitted stacker is kept as is; only the sigmoid is fitted on CAL.
        try:
            calibrator = CalibratedClassifierCV(estimator=stacker_raw, method="sigmoid", cv="prefit")
        except TypeError:  # older sklearn
            calibrator = CalibratedClassifierCV(base_estimator=stacker_raw, method="sigmoid", cv="prefit")
        stacker = calibrator.fit(X_cal_sc_df, y_cal_bin)
    else:
        print(f"[WARN] Skipping calibration for {name}: single-class cal set.")
        stacker = stacker_raw

    # ---- Calibration plot on all-test (optional)
    # Best effort: any failure here is reported and does not stop the class from being saved.
    try:
        y_proba_uncal = stacker_raw.predict_proba(X_te_all_sc_df)[:, 1]
        y_proba_cal   = stacker.predict_proba(X_te_all_sc_df)[:, 1]
        if has_both:
            _ = MLTraining.plot_calibration_curve(
                y_te_all, [y_proba_uncal, y_proba_cal],
                clf_names=["Uncalibrated", "Calibrated"],
                n_bins=15, strategy="quantile",
                title=f"Calibration – {name_target_class}:{name}"
            )
        else:
            _ = MLTraining.plot_calibration_curve(
                y_te_all, [y_proba_uncal],
                clf_names=["Uncalibrated"],
                n_bins=15, strategy="quantile",
                title=f"Calibration (uncal only) – {name_target_class}:{name}"
            )
        plt.tight_layout()
        plt.show()
    except Exception as e:
        print(f"[WARN] Skipped calibration plot for {name}: {e}")

    # ---- Save per-class bundle (model + scaler + columns)
    save_subdir = models_dir / f"{name_target_class}_{name}"
    save_subdir.mkdir(parents=True, exist_ok=True)

    MLTraining.save_models({"Stacked": stacker}, out_dir=save_subdir, tag=f"{name_target_class}_{name}")
    joblib.dump(cols_train, save_subdir / "feature_names.joblib")

    # The bundle carries what is needed to score new data with this head:
    # the model, the scaler fitted for it, and the feature columns in order.
    bundle = {
        "atlas": "Zhang",
        "depth": name_target_class,
        "label": celltype,
        "model": stacker,          # CalibratedClassifierCV(StackingClassifier) or StackingClassifier
        "columns": cols_train,     # exact panel names, panel order
        "scaler": scaler,          # per-head scaler
        "panel_name": "TotalSeqD_Heme_Oncology_CAT399906",
    }
    bundle_path = save_subdir / f"{name_target_class}_{name}_bundle.joblib"
    joblib.dump(bundle, bundle_path)
    print(f"[SAVE] Wrote bundle with columns+scaler to {bundle_path}")

    # ---- Binary metrics on the class-specific test slice
    # Best effort as well: a failure is logged and the class simply has no metrics row.
    try:
        m = MLTraining.evaluate_classifier(stacker, X_te_sc_df, y_te, plot_cm=False)
        m.update(celltype=celltype)
        metrics_log.append(m)
        print(f"\n{celltype}\n", m.get("report", ""))
    except Exception as e:
        print(f"[WARN] Binary metrics for {name} skipped: {e}")

    # ---- Store OvR probs for multiclass calibration (columns order = class_names)
    # Positive-class probability of this head over ALL cal / test rows.
    P_cal[:, k] = stacker.predict_proba(X_cal_sc_df)[:, 1]
    P_te[:,  k] = stacker.predict_proba(X_te_all_sc_df)[:, 1]

# ============================= MULTICLASS CALIBRATION =============================

print("\nFitting multiclass TemperatureScaling on CAL split (excluded classes masked out)...")

# Guards: ensure probs are in [0,1]
if (P_cal < 0).any() or (P_cal > 1).any():
    raise ValueError("P_cal must be probabilities in [0,1].")
if (P_te < 0).any() or (P_te > 1).any():
    raise ValueError("P_te must be probabilities in [0,1].")

# A class skipped in the training loop keeps its preallocated all-zero column;
# surface that instead of letting it enter the calibration unnoticed.
if not P_cal.any(axis=0).all():
    print(f"[WARN] All-zero OvR columns (head skipped during training?): "
          f"{[c for c, filled in zip(class_names, P_cal.any(axis=0)) if not filled]}")

ts_cal = TemperatureScaling()
# Fit only on CAL rows whose true label is one of the included classes
ts_cal.fit(P_cal[mask_cal_mc.values, :], y_cal_multiclass)
# The transform is applied to ALL test rows; masking happens at evaluation time.
P_te_mc = ts_cal.transform(P_te)

# Ensure calibrated probs have shape (n_test, K)
P_te_mc = np.asarray(P_te_mc)
if P_te_mc.ndim == 1:
    P_te_mc = P_te_mc.reshape(-1, 1)
if P_te_mc.shape[1] == 1 and K == 2:
    # A single column for a two-class problem: expand to [P(class 0), P(class 1)].
    P_te_mc = np.hstack([1.0 - P_te_mc, P_te_mc])
elif P_te_mc.shape[1] != K:
    # Report the shape that was actually returned BEFORE it is replaced by the fallback.
    print(f"[WARN] TemperatureScaling returned shape {P_te_mc.shape}; fell back to sum-normalized OvR probs.")
    # Fallback: normalize OvR sums (rows summing to zero are left as zeros)
    row_sums = P_te.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0.0] = 1.0
    P_te_mc = P_te / row_sums

# Persist multiclass temp scaler + INCLUDED class names
# The row order of the class-name CSV is the column order of the probability matrices.
joblib.dump(ts_cal, models_dir / f"{name_target_class}_multiclass_temp_scaler.joblib")
(pd.Series(class_names, name="class_name")
   .to_csv(models_dir / f"{name_target_class}_class_names.csv", index=False))

# ============================= PROBS COMPARISON & METRICS =============================

# Evaluate & save on TEST rows whose true label is an INCLUDED class
test_index_masked = Zhang_data_Test_Sub.index[mask_test_mc.values]

# One row per kept TEST cell: raw OvR probabilities ("raw_*") next to the
# temperature-scaled ones ("mc_*").
probs_raw_df = pd.DataFrame(P_te[mask_test_mc.values, :],    index=test_index_masked,
                            columns=[f"raw_{c}" for c in class_names])
probs_mc_df  = pd.DataFrame(P_te_mc[mask_test_mc.values, :], index=test_index_masked,
                            columns=[f"mc_{c}"  for c in class_names])

# Add the true label and the argmax prediction (as class index and as class name)
# before and after the multiclass calibration.
probs_compare = pd.concat([probs_raw_df, probs_mc_df], axis=1)
probs_compare["true_label"]    = Zhang_data_Test.loc[mask_test_mc, "Celltype"].values
probs_compare["pred_raw"]      = P_te[mask_test_mc.values, :].argmax(axis=1)
probs_compare["pred_mc"]       = P_te_mc[mask_test_mc.values, :].argmax(axis=1)
probs_compare["pred_raw_name"] = [class_names[i] for i in probs_compare["pred_raw"].values]
probs_compare["pred_mc_name"]  = [class_names[i] for i in probs_compare["pred_mc"].values]

print("\nPreview of probabilities BEFORE (raw OvR) vs AFTER (multiclass TS) [included classes only]:")
print(probs_compare.head(10).to_string())

# The full table is written with the barcode index as first column.
probs_compare_path = models_dir / f"{name_target_class}_probabilities_before_after_TEST_included.csv"
probs_compare.to_csv(probs_compare_path, index=True)
print(f"\nSaved probabilities comparison to: {probs_compare_path}")

# Multiclass evaluation on the masked subset
# Predicted class = column with the highest calibrated probability, restricted
# to the TEST rows whose true label is an included class.
y_pred_mc = P_te_mc[mask_test_mc.values, :].argmax(axis=1)
print("\nMulticlass classification report (TEST, excluded classes removed):")
# `labels` is passed explicitly so that the report always has one row per entry
# of class_names. Without it the labels are inferred from the data, and the call
# fails when a class occurs in neither the true nor the predicted TEST labels.
print(classification_report(
    y_test_multiclass, y_pred_mc,
    labels=list(range(K)), target_names=class_names, digits=3,
))

cm = confusion_matrix(y_test_multiclass, y_pred_mc, labels=range(K))
print("Confusion matrix (rows=true, cols=pred):\n", cm)

# Per-class binary head metrics CSV
# metrics_log holds one record per class whose binary evaluation succeeded in the loop.
metrics_df = pd.DataFrame.from_records(metrics_log)
MLTraining.append_metrics_csv(metrics_df, csv_path=Path(models_output) / "stacker_metrics.csv")

print("\nDone.")


#### Detailed annotation

In [None]:
# -*- coding: utf-8 -*-
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.ensemble import StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.calibration import CalibratedClassifierCV
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix

from netcal.scaling import TemperatureScaling
import joblib

# ------------------------------------------------------------------- CONFIG (expects these to already exist)
#   models_output, train_barcodes_path, test_barcodes_path
#   Zhang_data_Train, Zhang_data_Test, Zhang_data_Cal          (DataFrames indexed by barcode)
#   Zhang_dataset_Train, Zhang_dataset_Test, Zhang_dataset_Cal (AnnData with obs labels)
#   TotalSeqD_Heme_Oncology_CAT399906                    (iterable of feature names)
#   MLTraining module with: CV, train_NB, train_XGB, train_KNN, train_MLP,
#                           plot_calibration_curve, save_models, evaluate_classifier, append_metrics_csv

# Annotation depth handled by this cell. Its lower-cased form selects the AnnData
# label column (see consensus_field below) and the value prefixes saved artefacts.
name_target_class = "Detailed"   # "Broad" | "Simplified" | "Detailed"
# Output folders are created up front so later saves do not fail on a missing directory.
fig_root   = Path(models_output) / "Figures"
models_dir = Path(models_output) / "Models"
fig_root.mkdir(parents=True, exist_ok=True)
models_dir.mkdir(parents=True, exist_ok=True)

# NOTE: these rebind the notebook-level `kf` and `num_cores` defined in the set-up cell.
kf         = MLTraining.CV
num_cores  = -1   # forwarded to MLTraining.train_*; presumably "use all cores" — confirm in MLTraining
metrics_log = []  # one metrics record per binary head, appended inside the per-class loop

# ============================= HELPERS =============================

def _norm_feats(names) -> pd.Index:
    """
    Normalizer used ONLY to construct matching keys.
    Panel names remain untouched; data columns are normalized and then mapped BACK
    to the exact panel names via a lookup.
    """
    s = pd.Index(map(str, names))
    s = (s.str.strip()
           .str.lower()
           .str.replace(r"[ _/]+", "-", regex=True)
           .str.replace(r"-+", "-", regex=True)
           .str.strip("-"))
    return s

def attach_celltype(df: pd.DataFrame, ad: "AnnData", field: str) -> pd.DataFrame:
    """
    Return a copy of ``df`` with a categorical ``Celltype`` column taken from ``ad.obs[field]``.

    Labels are cast to string, stripped, and every whitespace run is replaced by an
    underscore. They are then aligned to ``df`` by index; rows whose index is not
    found in ``ad.obs`` receive NaN and a warning with their count is printed.

    Raises
    ------
    KeyError
        If ``field`` is not present in ``ad.obs``.
    """
    obs = ad.obs
    if field not in obs:
        raise KeyError(f"'{field}' not found in AnnData.obs")

    labels = obs[field].astype("string")
    labels = labels.str.strip()
    labels = labels.str.replace(r"\s+", "_", regex=True)

    out = df.copy()
    # reindex aligns labels to the data rows by index; unmatched rows become NaN.
    out["Celltype"] = pd.Categorical(labels.reindex(out.index))

    missing = int(out["Celltype"].isna().sum())
    if missing:
        print(f"[WARN] {missing} rows got NaN Celltype after reindex; check barcode alignment.")
    return out

def _check_finite(df: pd.DataFrame, tag: str):
    arr = df.to_numpy()
    if not np.isfinite(arr).all():
        bad = np.where(~np.isfinite(arr))
        raise ValueError(f"Non-finite values found in {tag} features at positions {bad}")

def _unwrap_estimator(m):
    return getattr(m, "estimator", None) or getattr(m, "base_estimator", None) or m

def _assert_feature_counts(cell_name: str, models_dict: dict, expected: int):
    """
    Verify that every supplied model was fitted on ``expected`` features.

    Looks up the keys "NB", "XGB", "KNN", "MLP" and "Stacker" in ``models_dict``.
    Missing/None entries are skipped, as are models whose unwrapped estimator
    does not expose ``n_features_in_``.

    Raises
    ------
    RuntimeError
        For the first model whose ``n_features_in_`` differs from ``expected``.
    """
    for name in ("NB", "XGB", "KNN", "MLP", "Stacker"):
        est = models_dict.get(name)
        if est is None:
            continue
        nfi = getattr(_unwrap_estimator(est), "n_features_in_", None)
        if nfi is None or nfi == expected:
            continue
        raise RuntimeError(f"{cell_name}:{name} saw {nfi} features; expected {expected}")

# ============================= LABEL ATTACH =============================

# Label column in AnnData.obs, derived from the annotation depth chosen in the config.
consensus_field = f"Consensus_annotation_{name_target_class.lower()}_final"

# Each split gets a "Celltype" column looked up by index in the matching AnnData.
# NOTE: the split frames are rebound here, so this cell overwrites its own inputs.
Zhang_data_Train = attach_celltype(Zhang_data_Train, Zhang_dataset_Train, consensus_field)
Zhang_data_Test  = attach_celltype(Zhang_data_Test,  Zhang_dataset_Test,  consensus_field)
Zhang_data_Cal   = attach_celltype(Zhang_data_Cal,   Zhang_dataset_Cal,   consensus_field)

# ============================= PANEL & DATA COLUMN ALIGNMENT =============================

# Keep the panel EXACTLY as provided (only cast each name to str)
panel = pd.Index(map(str, TotalSeqD_Heme_Oncology_CAT399906))

# Build a mapping: normalized_key -> exact panel name
panel_keys    = _norm_feats(panel)
norm_to_panel = dict(zip(panel_keys, panel))
# dict() keeps one entry per key, so a shorter dict means two panel names share a key.
if len(norm_to_panel) != len(panel):
    raise ValueError("Panel contains names that collide after normalization. Consider adjusting _norm_feats rules.")

def _rename_data_to_panel(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of ``df`` whose feature columns carry the exact panel names.

    A feature column is renamed when its normalized key is present in
    ``norm_to_panel``; other feature columns keep their name. 'cell_barcode' and
    'Celltype' are never touched. When several data columns resolve to the same
    panel name, the first one (in column order) is kept and the others are dropped.
    """
    out = df.copy()
    reserved = [c for c in ["cell_barcode", "Celltype"] if c in out.columns]
    feat = pd.Index([c for c in out.columns if c not in reserved])

    # old column name -> exact panel name, only for columns the panel knows
    rename_map = {}
    for old, key in zip(feat, _norm_feats(feat)):
        new = norm_to_panel.get(key)
        if new is not None:
            rename_map[old] = new

    # The first column claiming a panel name wins; later claimants are dropped.
    claimed = set()
    safe_map = {}
    drops = []
    for old, new in rename_map.items():
        if new in claimed:
            drops.append(old)
            continue
        claimed.add(new)
        safe_map[old] = new

    if drops:
        print(f"[WARN] Dropping {len(drops)} duplicated-mapped columns (showing up to 5): {drops[:5]}")
        out = out.drop(columns=drops, errors="ignore")
    out = out.rename(columns=safe_map)

    matched = len(safe_map)
    print(f"[map] matched {matched}/{len(feat)} data columns to panel")
    return out

# Apply: rename ONLY the data splits to the exact panel names (panel remains untouched).
# Each call returns a new frame; the split variables are rebound to the renamed copies.
Zhang_data_Train = _rename_data_to_panel(Zhang_data_Train)
Zhang_data_Test  = _rename_data_to_panel(Zhang_data_Test)
Zhang_data_Cal   = _rename_data_to_panel(Zhang_data_Cal)

# Intersect each split with the panel IN PANEL ORDER
def _panel_intersection(df: pd.DataFrame) -> pd.DataFrame:
    """
    Keep only the feature columns that are also in ``panel``, in panel order.

    The non-feature columns 'cell_barcode' and 'Celltype' (when present) are
    appended after the features.

    Raises
    ------
    ValueError
        If no feature column of ``df`` belongs to the panel.
    """
    non_feat = [c for c in ("cell_barcode", "Celltype") if c in df.columns]
    feat_cols = pd.Index([c for c in df.columns if c not in non_feat])

    # sort=False keeps the order of the left operand, i.e. the panel.
    inter = panel.intersection(feat_cols, sort=False)
    if len(inter) == 0:
        raise ValueError("Panel/Data intersection is empty after renaming. Check mapping rules.")

    ordered = [*inter, *non_feat]
    return df.reindex(columns=ordered)

# Restrict every split to panel features (panel order) plus the non-feature columns.
Zhang_data_Train = _panel_intersection(Zhang_data_Train)
Zhang_data_Test  = _panel_intersection(Zhang_data_Test)
Zhang_data_Cal   = _panel_intersection(Zhang_data_Cal)

# ============================= FEATURES & LABELS =============================

# Keep the CAL labels in their own frame; the feature frame below drops "Celltype".
Zhang_data_Cal_lbl = Zhang_data_Cal[["Celltype"]].copy()

# Non-feature columns present in each split (only those that actually exist).
drop_cols_train = [c for c in ["cell_barcode", "Celltype"] if c in Zhang_data_Train.columns]
drop_cols_test  = [c for c in ["cell_barcode", "Celltype"] if c in Zhang_data_Test.columns]
drop_cols_cal   = [c for c in ["cell_barcode", "Celltype"] if c in Zhang_data_Cal.columns]

# *_Sub frames hold features only; they share the row index of their source split.
Zhang_data_Train_Sub = Zhang_data_Train.drop(columns=drop_cols_train, errors="ignore")
Zhang_data_Test_Sub  = Zhang_data_Test.drop(columns=drop_cols_test,  errors="ignore")
Zhang_data_Cal_Sub   = Zhang_data_Cal.drop(columns=drop_cols_cal,    errors="ignore")

# SAFETY: shared columns & finiteness checks
# The comparison is on ordered lists, so both the column set and the order must match.
cols_train = list(Zhang_data_Train_Sub.columns)
if list(Zhang_data_Test_Sub.columns) != cols_train or list(Zhang_data_Cal_Sub.columns) != cols_train:
    raise ValueError("Train/Cal/Test feature columns differ after panel intersection!")

_check_finite(Zhang_data_Train_Sub, "TRAIN")
_check_finite(Zhang_data_Test_Sub,  "TEST")
_check_finite(Zhang_data_Cal_Sub,   "CAL")

print(f"\n[features] Using {len(cols_train)} panel-intersected features (exact panel names):")
print(cols_train)

# ===== Exclude specific classes from the multiclass set and per-class loop =====
EXCLUDE_CLASSES = {"Macrophage", "ILC", "Stroma", "dnT"}

# Class order is the sorted order of the labels seen in TRAIN; it fixes the column
# order of the probability matrices below.
all_classes = sorted(pd.Series(Zhang_data_Train["Celltype"]).dropna().unique())
class_names = [c for c in all_classes if c not in EXCLUDE_CLASSES]
if not class_names:
    raise ValueError("After exclusions, class_names is empty.")
print(f"[classes] Included ({len(class_names)}): {class_names}")
# Only the excluded classes that actually occur in TRAIN are reported.
if missing := [c for c in all_classes if c in EXCLUDE_CLASSES]:
    print(f"[classes] Excluded: {missing}")

K            = len(class_names)
class_to_idx = {c: i for i, c in enumerate(class_names)}

# --- Multiclass labels (MASKED to included classes) ---
# CAL
# Rows with an excluded, unseen or NaN label are False in the mask and are left out
# of the integer label vector.
mask_cal_mc = Zhang_data_Cal_lbl["Celltype"].isin(class_names)
s_cal = Zhang_data_Cal_lbl.loc[mask_cal_mc, "Celltype"].map(class_to_idx)
# Defensive check: after the isin() mask every remaining label should map to an index.
if s_cal.isna().any():
    missing = Zhang_data_Cal_lbl.loc[mask_cal_mc & s_cal.isna(), "Celltype"].unique()
    raise ValueError(f"Unknown labels in masked CAL split: {missing}")
y_cal_multiclass = s_cal.to_numpy(dtype=np.int64)

# TEST
mask_test_mc = Zhang_data_Test["Celltype"].isin(class_names)
s_te = Zhang_data_Test.loc[mask_test_mc, "Celltype"].map(class_to_idx)
if s_te.isna().any():
    missing = Zhang_data_Test.loc[mask_test_mc & s_te.isna(), "Celltype"].unique()
    raise ValueError(f"Unknown labels in masked TEST split: {missing}")
y_test_multiclass = s_te.to_numpy(dtype=np.int64)

# Reuse across classes
X_cal_all_df = Zhang_data_Cal_Sub.copy()
X_te_all_df  = Zhang_data_Test_Sub.copy()

# Preallocate OvR prob mats (only for included classes)
# Rows follow the FULL cal/test splits (unmasked); column k belongs to class_names[k].
P_cal = np.zeros((X_cal_all_df.shape[0], K), dtype=float)
P_te  = np.zeros((X_te_all_df.shape[0],  K), dtype=float)

test_index = Zhang_data_Test_Sub.index

# ============================= TRAIN PER-CLASS OVR =============================

# One binary (one-vs-rest) head is trained per included class. Each iteration fits a
# scaler and four base learners on that class's TRAIN barcodes, stacks them, optionally
# sigmoid-calibrates the stack on the CAL split, saves the artefacts, and writes the
# head's positive-class probabilities into column k of P_cal / P_te.
for celltype in class_names:
    k = class_to_idx[celltype]
    name = str(celltype).replace(" ", "_")
    print(f"\nProcessing {name} (class {k+1}/{K})...")

    # ---- TRAIN slice via barcode lists
    # The CSV provides "Positive" and "Negative" barcode columns; dropna() removes
    # empty cells (presumably padding when the two columns differ in length).
    train_barcodes_df = pd.read_csv(
        f"{train_barcodes_path}/Zhang/Consensus_annotation_detailed_final/Barcodes_training_class_{name}.csv",
        index_col=0
    )
    train_positive_barcodes = train_barcodes_df["Positive"].dropna().values
    train_negative_barcodes = train_barcodes_df["Negative"].dropna().values
    all_train_barcodes = np.concatenate([train_positive_barcodes, train_negative_barcodes])

    # Only barcodes actually present in the TRAIN frame are used; labels are derived
    # from membership in the positive list, in the frame's row order.
    train_mask = Zhang_data_Train_Sub.index.isin(all_train_barcodes)
    X_tr_df = Zhang_data_Train_Sub.loc[train_mask]
    found_train_barcodes = X_tr_df.index.values
    y_tr = np.isin(found_train_barcodes, train_positive_barcodes).astype(int)

    # ---- Skip guards
    # NOTE(review): a skipped class keeps its all-zero column in P_cal / P_te (both are
    # preallocated with zeros) yet stays in class_names, so the multiclass step still
    # sees it.
    if X_tr_df.empty or np.unique(y_tr).size < 2:
        print(f"[SKIP] {name}: empty or single-class train slice (pos={y_tr.sum()}, neg={(len(y_tr)-y_tr.sum())}).")
        continue

    # ---- TEST slice via barcode lists
    test_barcodes_df = pd.read_csv(
        f"{test_barcodes_path}/Zhang/Consensus_annotation_detailed_final/Barcodes_testing_class_{name}.csv",
        index_col=0
    )
    test_positive_barcodes = test_barcodes_df["Positive"].dropna().values
    test_negative_barcodes = test_barcodes_df["Negative"].dropna().values
    all_test_barcodes = np.concatenate([test_positive_barcodes, test_negative_barcodes])

    test_mask = Zhang_data_Test_Sub.index.isin(all_test_barcodes)
    X_te_df = Zhang_data_Test_Sub.loc[test_mask]
    found_test_barcodes = X_te_df.index.values
    y_te = np.isin(found_test_barcodes, test_positive_barcodes).astype(int)

    # ---- Full-test & cal for this binary head
    # Binary targets over the complete TEST and CAL splits: 1 where the attached
    # Celltype equals this class, 0 otherwise.
    X_te_all_local = X_te_all_df.copy()
    y_te_all = (Zhang_data_Test["Celltype"].values == celltype).astype(int)
    X_cal_df = X_cal_all_df.copy()
    y_cal_bin = (Zhang_data_Cal_lbl["Celltype"].values == celltype).astype(int)

    # ---- Info
    print(f"Training - Found {X_tr_df.shape[0]} / {len(all_train_barcodes)} barcodes")
    print(f"Training - Pos: {len(train_positive_barcodes)}, Neg: {len(train_negative_barcodes)}")
    print(f"Training labels: {y_tr.sum()} pos, {len(y_tr)-y_tr.sum()} neg")
    print(f"Testing  - Found {X_te_df.shape[0]} / {len(all_test_barcodes)} barcodes")
    print(f"Testing  - Pos: {len(test_positive_barcodes)}, Neg: {len(test_negative_barcodes)}")
    print(f"Testing  - labels: {y_te.sum()} pos, {len(y_te)-y_te.sum()} neg")
    print(f"Calibrating - Found {X_cal_df.shape[0]} rows | Pos: {y_cal_bin.sum()}, Neg: {len(y_cal_bin)-y_cal_bin.sum()}")
    print(f"All test data: {X_te_all_local.shape[0]} rows, positives for {celltype}: {y_te_all.sum()}")

    # ---- Scaling (fit on per-head TRAIN slice; transform others)
    scaler = StandardScaler(with_mean=True, with_std=True).fit(X_tr_df.values)

    # Helper bound to this iteration's scaler: returns a scaled DataFrame that keeps
    # the input's index and uses cols_train as column names.
    def _sc(df):
        return pd.DataFrame(
            scaler.transform(df.values),
            index=df.index,
            columns=cols_train,
        )

    X_tr_sc_df     = _sc(X_tr_df)
    X_te_sc_df     = _sc(X_te_df)
    X_te_all_sc_df = _sc(X_te_all_local)
    X_cal_sc_df    = _sc(X_cal_df)

    print(f"[scale] {name}: train mean ~ {X_tr_sc_df.values.mean():.3f}, std ~ {X_tr_sc_df.values.std():.3f}")

    # ---- Base learners
    NB_model  = MLTraining.train_NB (X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    XGB_model = MLTraining.train_XGB(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    KNN_model = MLTraining.train_KNN(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    MLP_model = MLTraining.train_MLP(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)

    # ---- Stacker (raw)
    # The logistic-regression meta-learner is trained on the base learners'
    # predict_proba outputs produced under the cv=kf scheme.
    stacker_raw = StackingClassifier(
        estimators=[("NB", NB_model), ("XGB", XGB_model), ("KNN", KNN_model), ("MLP", MLP_model)],
        final_estimator=LogisticRegression(max_iter=2000, class_weight="balanced", random_state=42),
        stack_method="predict_proba",
        cv=kf,
        n_jobs=-1,
    ).fit(X_tr_sc_df, y_tr)

    # ---- Feature count asserts (debug safety)
    expected_feats = len(cols_train)
    _assert_feature_counts(name, {
        "NB": NB_model, "XGB": XGB_model, "KNN": KNN_model, "MLP": MLP_model, "Stacker": stacker_raw
    }, expected_feats)

    # ---- Binary calibration (Platt, guarded)
    # Calibration is only attempted when the CAL split holds both positives and negatives.
    pos_cal    = int(y_cal_bin.sum())
    n_cal_bin  = int(len(y_cal_bin))
    has_both   = (0 < pos_cal < n_cal_bin)
    print(f"[CAL] {name}: cal positives={pos_cal}/{n_cal_bin}")

    if has_both:
        # cv="prefit" wraps the already-fitted stacker; the keyword for the wrapped
        # model is `estimator` in newer sklearn and `base_estimator` in older releases.
        try:
            calibrator = CalibratedClassifierCV(estimator=stacker_raw, method="sigmoid", cv="prefit")
        except TypeError:  # older sklearn
            calibrator = CalibratedClassifierCV(base_estimator=stacker_raw, method="sigmoid", cv="prefit")
        stacker = calibrator.fit(X_cal_sc_df, y_cal_bin)
    else:
        print(f"[WARN] Skipping calibration for {name}: single-class cal set.")
        stacker = stacker_raw

    # ---- Calibration plot on all-test (optional)
    # Plotting is best-effort: any failure is reported and the loop continues.
    try:
        y_proba_uncal = stacker_raw.predict_proba(X_te_all_sc_df)[:, 1]
        y_proba_cal   = stacker.predict_proba(X_te_all_sc_df)[:, 1]
        if has_both:
            _ = MLTraining.plot_calibration_curve(
                y_te_all, [y_proba_uncal, y_proba_cal],
                clf_names=["Uncalibrated", "Calibrated"],
                n_bins=15, strategy="quantile",
                title=f"Calibration – {name_target_class}:{name}"
            )
        else:
            _ = MLTraining.plot_calibration_curve(
                y_te_all, [y_proba_uncal],
                clf_names=["Uncalibrated"],
                n_bins=15, strategy="quantile",
                title=f"Calibration (uncal only) – {name_target_class}:{name}"
            )
        plt.tight_layout()
        plt.show()
    except Exception as e:
        print(f"[WARN] Skipped calibration plot for {name}: {e}")

    # ---- Save per-class bundle (model + scaler + columns)
    save_subdir = models_dir / f"{name_target_class}_{name}"
    save_subdir.mkdir(parents=True, exist_ok=True)

    MLTraining.save_models({"Stacked": stacker}, out_dir=save_subdir, tag=f"{name_target_class}_{name}")
    joblib.dump(cols_train, save_subdir / "feature_names.joblib")

    # The bundle carries the model together with the scaler and column order it was
    # trained with.
    bundle = {
        "atlas": "Zhang",
        "depth": name_target_class,
        "label": celltype,
        "model": stacker,          # CalibratedClassifierCV(StackingClassifier) or StackingClassifier
        "columns": cols_train,     # exact panel names, panel order
        "scaler": scaler,          # per-head scaler
        "panel_name": "TotalSeqD_Heme_Oncology_CAT399906",
    }
    bundle_path = save_subdir / f"{name_target_class}_{name}_bundle.joblib"
    joblib.dump(bundle, bundle_path)
    print(f"[SAVE] Wrote bundle with columns+scaler to {bundle_path}")

    # ---- Binary metrics on the class-specific test slice
    # Best-effort as well: a failure here does not stop training of the remaining heads.
    try:
        m = MLTraining.evaluate_classifier(stacker, X_te_sc_df, y_te, plot_cm=False)
        m.update(celltype=celltype)
        metrics_log.append(m)
        print(f"\n{celltype}\n", m.get("report", ""))
    except Exception as e:
        print(f"[WARN] Binary metrics for {name} skipped: {e}")

    # ---- Store OvR probs for multiclass calibration (columns order = class_names)
    # Column 1 of predict_proba is the probability of the positive label (1).
    P_cal[:, k] = stacker.predict_proba(X_cal_sc_df)[:, 1]
    P_te[:,  k] = stacker.predict_proba(X_te_all_sc_df)[:, 1]

# ============================= MULTICLASS CALIBRATION =============================

print("\nFitting multiclass TemperatureScaling on CAL split (excluded classes masked out)...")

# Guards: ensure probs are in [0,1]
if (P_cal < 0).any() or (P_cal > 1).any():
    raise ValueError("P_cal must be probabilities in [0,1].")
if (P_te < 0).any() or (P_te > 1).any():
    raise ValueError("P_te must be probabilities in [0,1].")

ts_cal = TemperatureScaling()
# Fit only on CAL rows whose true label is one of the included classes
ts_cal.fit(P_cal[mask_cal_mc.values, :], y_cal_multiclass)
# The transform is applied to ALL test rows; masking to included classes happens later.
P_te_mc = ts_cal.transform(P_te)

# Ensure calibrated probs shape (K)
P_te_mc = np.asarray(P_te_mc)
# Remember what the scaler actually returned: P_te_mc may be reshaped or replaced
# below, and the fallback warning must report the original shape, not the new one.
ts_returned_shape = P_te_mc.shape
if P_te_mc.ndim == 1:
    P_te_mc = P_te_mc.reshape(-1, 1)
if P_te_mc.shape[1] == 1 and K == 2:
    # Single column in the binary case: expand it into [1 - p, p].
    P_te_mc = np.hstack([1.0 - P_te_mc, P_te_mc])
elif P_te_mc.shape[1] != K:
    # Fallback: normalize OvR sums (rows summing to 0 are divided by 1 and stay all-zero)
    row_sums = P_te.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0.0] = 1.0
    P_te_mc = P_te / row_sums
    print(f"[WARN] TemperatureScaling returned shape {ts_returned_shape}; fell back to sum-normalized OvR probs.")

# Persist multiclass temp scaler + INCLUDED class names
# NOTE(review): the scaler is saved even when the fallback above was taken.
joblib.dump(ts_cal, models_dir / f"{name_target_class}_multiclass_temp_scaler.joblib")
(pd.Series(class_names, name="class_name")
   .to_csv(models_dir / f"{name_target_class}_class_names.csv", index=False))

# ============================= PROBS COMPARISON & METRICS =============================

# Evaluate & save on TEST rows whose true label is an INCLUDED class.
# mask_test_mc is applied positionally (.values) so probability rows, barcodes and
# labels all follow the row order of the test frame.
test_index_masked = Zhang_data_Test_Sub.index[mask_test_mc.values]

probs_raw_df = pd.DataFrame(P_te[mask_test_mc.values, :],    index=test_index_masked,
                            columns=[f"raw_{c}" for c in class_names])
probs_mc_df  = pd.DataFrame(P_te_mc[mask_test_mc.values, :], index=test_index_masked,
                            columns=[f"mc_{c}"  for c in class_names])

# Side-by-side table: raw OvR probabilities, temperature-scaled probabilities,
# the true label and the argmax prediction (index and name) of each variant.
probs_compare = pd.concat([probs_raw_df, probs_mc_df], axis=1)
probs_compare["true_label"]    = Zhang_data_Test.loc[mask_test_mc, "Celltype"].values
probs_compare["pred_raw"]      = P_te[mask_test_mc.values, :].argmax(axis=1)
probs_compare["pred_mc"]       = P_te_mc[mask_test_mc.values, :].argmax(axis=1)
probs_compare["pred_raw_name"] = [class_names[i] for i in probs_compare["pred_raw"].values]
probs_compare["pred_mc_name"]  = [class_names[i] for i in probs_compare["pred_mc"].values]

print("\nPreview of probabilities BEFORE (raw OvR) vs AFTER (multiclass TS) [included classes only]:")
print(probs_compare.head(10).to_string())

probs_compare_path = models_dir / f"{name_target_class}_probabilities_before_after_TEST_included.csv"
probs_compare.to_csv(probs_compare_path, index=True)
print(f"\nSaved probabilities comparison to: {probs_compare_path}")

# Multiclass evaluation on the masked subset
y_pred_mc = P_te_mc[mask_test_mc.values, :].argmax(axis=1)
print("\nMulticlass classification report (TEST, excluded classes removed):")
# labels= pins the report to all K class indices. Without it sklearn infers the label
# set from y_true/y_pred and raises ValueError as soon as one class is neither present
# in the test rows nor ever predicted, because the inferred set no longer matches
# target_names. confusion_matrix below already pins the same label set.
print(classification_report(y_test_multiclass, y_pred_mc, labels=list(range(K)),
                            target_names=class_names, digits=3))

cm = confusion_matrix(y_test_multiclass, y_pred_mc, labels=range(K))
print("Confusion matrix (rows=true, cols=pred):\n", cm)

# Per-class binary head metrics CSV
metrics_df = pd.DataFrame.from_records(metrics_log)
MLTraining.append_metrics_csv(metrics_df, csv_path=Path(models_output) / "stacker_metrics.csv")

print("\nDone.")


## Triana Models

In [None]:
# Create the Triana output folders (no error if they already exist).
os.makedirs(data_path + "/Triana", exist_ok=True)
os.makedirs(data_path + "/Triana/Models", exist_ok=True)

# NOTE: rebinds the notebook-level `models_output`; the training cells below build
# their Figures/ and Models/ paths from it.
models_output = data_path + "/Triana/"

### ML Training

In [None]:
# Empty dict for the Triana models. NOTE(review): nothing in the cells visible here
# fills it — confirm it is still used before keeping it.
Triana_Models = {}

#### Broad annotation

In [None]:
# -*- coding: utf-8 -*-
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.ensemble import StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.calibration import CalibratedClassifierCV
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix

from netcal.scaling import TemperatureScaling
import joblib

# ------------------------------------------------------------------- CONFIG (expects these to already exist)
#   models_output, train_barcodes_path, test_barcodes_path
#   Triana_data_Train, Triana_data_Test, Triana_data_Cal          (DataFrames indexed by barcode)
#   Triana_dataset_Train, Triana_dataset_Test, Triana_dataset_Cal (AnnData with obs labels)
#   TotalSeqD_Heme_Oncology_CAT399906                    (iterable of feature names)
#   MLTraining module with: CV, train_NB, train_XGB, train_KNN, train_MLP,
#                           plot_calibration_curve, save_models, evaluate_classifier, append_metrics_csv

# Annotation depth handled by this cell. Its lower-cased form selects the AnnData
# label column (see consensus_field below).
name_target_class = "Broad"   # "Broad" | "Simplified" | "Detailed"
# Output folders are created up front so later saves do not fail on a missing directory.
fig_root   = Path(models_output) / "Figures"
models_dir = Path(models_output) / "Models"
fig_root.mkdir(parents=True, exist_ok=True)
models_dir.mkdir(parents=True, exist_ok=True)

# NOTE: these rebind the notebook-level `kf`, `num_cores` and `metrics_log`.
kf         = MLTraining.CV
num_cores  = -1   # presumably "use all cores" for the MLTraining helpers — confirm in MLTraining
metrics_log = []  # starts empty for this cell

# ============================= HELPERS =============================

def _norm_feats(names) -> pd.Index:
    """
    Normalizer used ONLY to construct matching keys.
    Panel names remain untouched; data columns are normalized and then mapped BACK
    to the exact panel names via a lookup.
    """
    s = pd.Index(map(str, names))
    s = (s.str.strip()
           .str.lower()
           .str.replace(r"[ _/]+", "-", regex=True)
           .str.replace(r"-+", "-", regex=True)
           .str.strip("-"))
    return s

def attach_celltype(df: pd.DataFrame, ad: "AnnData", field: str) -> pd.DataFrame:
    """
    Return a copy of ``df`` with a categorical ``Celltype`` column taken from ``ad.obs[field]``.

    Labels are cast to string, stripped, and every whitespace run is replaced by an
    underscore. They are then aligned to ``df`` by index; rows whose index is not
    found in ``ad.obs`` receive NaN and a warning with their count is printed.

    Raises
    ------
    KeyError
        If ``field`` is not present in ``ad.obs``.
    """
    obs = ad.obs
    if field not in obs:
        raise KeyError(f"'{field}' not found in AnnData.obs")

    labels = obs[field].astype("string")
    labels = labels.str.strip()
    labels = labels.str.replace(r"\s+", "_", regex=True)

    out = df.copy()
    # reindex aligns labels to the data rows by index; unmatched rows become NaN.
    out["Celltype"] = pd.Categorical(labels.reindex(out.index))

    missing = int(out["Celltype"].isna().sum())
    if missing:
        print(f"[WARN] {missing} rows got NaN Celltype after reindex; check barcode alignment.")
    return out

def _check_finite(df: pd.DataFrame, tag: str):
    arr = df.to_numpy()
    if not np.isfinite(arr).all():
        bad = np.where(~np.isfinite(arr))
        raise ValueError(f"Non-finite values found in {tag} features at positions {bad}")

def _unwrap_estimator(m):
    return getattr(m, "estimator", None) or getattr(m, "base_estimator", None) or m

def _assert_feature_counts(cell_name: str, models_dict: dict, expected: int):
    """
    Verify that every supplied model was fitted on ``expected`` features.

    Looks up the keys "NB", "XGB", "KNN", "MLP" and "Stacker" in ``models_dict``.
    Missing/None entries are skipped, as are models whose unwrapped estimator
    does not expose ``n_features_in_``.

    Raises
    ------
    RuntimeError
        For the first model whose ``n_features_in_`` differs from ``expected``.
    """
    for name in ("NB", "XGB", "KNN", "MLP", "Stacker"):
        est = models_dict.get(name)
        if est is None:
            continue
        nfi = getattr(_unwrap_estimator(est), "n_features_in_", None)
        if nfi is None or nfi == expected:
            continue
        raise RuntimeError(f"{cell_name}:{name} saw {nfi} features; expected {expected}")

# ============================= LABEL ATTACH =============================

# Label column in AnnData.obs, derived from the annotation depth chosen in the config.
consensus_field = f"Consensus_annotation_{name_target_class.lower()}_final"

# Each split gets a "Celltype" column looked up by index in the matching AnnData.
# NOTE: the split frames are rebound here, so this cell overwrites its own inputs.
Triana_data_Train = attach_celltype(Triana_data_Train, Triana_dataset_Train, consensus_field)
Triana_data_Test  = attach_celltype(Triana_data_Test,  Triana_dataset_Test,  consensus_field)
Triana_data_Cal   = attach_celltype(Triana_data_Cal,   Triana_dataset_Cal,   consensus_field)

# ============================= PANEL & DATA COLUMN ALIGNMENT =============================

# Keep the panel EXACTLY as provided (only cast each name to str)
panel = pd.Index(map(str, TotalSeqD_Heme_Oncology_CAT399906))

# Build a mapping: normalized_key -> exact panel name
panel_keys    = _norm_feats(panel)
norm_to_panel = dict(zip(panel_keys, panel))
# dict() keeps one entry per key, so a shorter dict means two panel names share a key.
if len(norm_to_panel) != len(panel):
    raise ValueError("Panel contains names that collide after normalization. Consider adjusting _norm_feats rules.")

def _rename_data_to_panel(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of ``df`` whose feature columns carry the exact panel names.

    A feature column is renamed when its normalized key is present in
    ``norm_to_panel``; other feature columns keep their name. 'cell_barcode' and
    'Celltype' are never touched. When several data columns resolve to the same
    panel name, the first one (in column order) is kept and the others are dropped.
    """
    out = df.copy()
    reserved = [c for c in ["cell_barcode", "Celltype"] if c in out.columns]
    feat = pd.Index([c for c in out.columns if c not in reserved])

    # old column name -> exact panel name, only for columns the panel knows
    rename_map = {}
    for old, key in zip(feat, _norm_feats(feat)):
        new = norm_to_panel.get(key)
        if new is not None:
            rename_map[old] = new

    # The first column claiming a panel name wins; later claimants are dropped.
    claimed = set()
    safe_map = {}
    drops = []
    for old, new in rename_map.items():
        if new in claimed:
            drops.append(old)
            continue
        claimed.add(new)
        safe_map[old] = new

    if drops:
        print(f"[WARN] Dropping {len(drops)} duplicated-mapped columns (showing up to 5): {drops[:5]}")
        out = out.drop(columns=drops, errors="ignore")
    out = out.rename(columns=safe_map)

    matched = len(safe_map)
    print(f"[map] matched {matched}/{len(feat)} data columns to panel")
    return out

# Apply: rename ONLY the data splits to the exact panel names (panel remains untouched).
# Each call returns a new frame; the split variables are rebound to the renamed copies.
Triana_data_Train = _rename_data_to_panel(Triana_data_Train)
Triana_data_Test  = _rename_data_to_panel(Triana_data_Test)
Triana_data_Cal   = _rename_data_to_panel(Triana_data_Cal)

# Intersect each split with the panel IN PANEL ORDER
def _panel_intersection(df: pd.DataFrame) -> pd.DataFrame:
    """
    Restrict `df` to the panel features it contains, followed by its metadata columns.

    Feature order comes from `panel.intersection(..., sort=False)`; the
    'cell_barcode' / 'Celltype' columns (when present) are appended last.

    Raises
    ------
    ValueError
        If no column of `df` matches a panel name.
    """
    meta_cols = [c for c in ["cell_barcode", "Celltype"] if c in df.columns]
    candidate_feats = pd.Index([c for c in df.columns if c not in meta_cols])
    shared = panel.intersection(candidate_feats, sort=False)
    if len(shared) == 0:
        raise ValueError("Panel/Data intersection is empty after renaming. Check mapping rules.")
    ordered_cols = [*shared, *meta_cols]
    return df.reindex(columns=ordered_cols)

# Keep only panel features (plus metadata columns) in every split.
Triana_data_Train, Triana_data_Test, Triana_data_Cal = (
    _panel_intersection(split)
    for split in (Triana_data_Train, Triana_data_Test, Triana_data_Cal)
)

# ============================= FEATURES & LABELS =============================

# Calibration labels live in their own one-column frame, detached from the feature matrix.
Triana_data_Cal_lbl = Triana_data_Cal.loc[:, ["Celltype"]].copy()

# Metadata columns actually present in each split; everything else is a feature.
drop_cols_train = [c for c in ["cell_barcode", "Celltype"] if c in Triana_data_Train.columns]
drop_cols_test  = [c for c in ["cell_barcode", "Celltype"] if c in Triana_data_Test.columns]
drop_cols_cal   = [c for c in ["cell_barcode", "Celltype"] if c in Triana_data_Cal.columns]

# Feature-only frames for the three splits.
Triana_data_Train_Sub = Triana_data_Train.drop(drop_cols_train, axis=1, errors="ignore")
Triana_data_Test_Sub  = Triana_data_Test.drop(drop_cols_test,   axis=1, errors="ignore")
Triana_data_Cal_Sub   = Triana_data_Cal.drop(drop_cols_cal,     axis=1, errors="ignore")

# SAFETY: every split must expose the same features in the same order.
cols_train = list(Triana_data_Train_Sub.columns)
if not all(list(sub.columns) == cols_train for sub in (Triana_data_Test_Sub, Triana_data_Cal_Sub)):
    raise ValueError("Train/Cal/Test feature columns differ after panel intersection!")

# SAFETY: no NaN / inf in any feature matrix.
_check_finite(Triana_data_Train_Sub, "TRAIN")
_check_finite(Triana_data_Test_Sub,  "TEST")
_check_finite(Triana_data_Cal_Sub,   "CAL")

print(f"\n[features] Using {len(cols_train)} panel-intersected features (exact panel names):")
print(cols_train)

# Class order is defined once (sorted TRAIN labels) and reused for every probability matrix.
class_names  = sorted(Triana_data_Train["Celltype"].dropna().unique())
K            = len(class_names)
class_to_idx = dict(zip(class_names, range(K)))

# Integer-encoded multiclass targets; a label that TRAIN never saw is an error.
s_cal = Triana_data_Cal_lbl["Celltype"].map(class_to_idx)
if s_cal.isna().any():
    missing = Triana_data_Cal_lbl.loc[s_cal.isna(), "Celltype"].unique()
    raise ValueError(f"Unknown labels in CAL split: {missing}")
y_cal_multiclass = s_cal.to_numpy(dtype=np.int64)

s_te = Triana_data_Test["Celltype"].map(class_to_idx)
if s_te.isna().any():
    missing = Triana_data_Test.loc[s_te.isna(), "Celltype"].unique()
    raise ValueError(f"Unknown labels in TEST split: {missing}")
y_test_multiclass = s_te.to_numpy(dtype=np.int64)

# Full CAL / TEST feature frames, shared by every binary head.
X_cal_all_df = Triana_data_Cal_Sub.copy()
X_te_all_df  = Triana_data_Test_Sub.copy()

# One-vs-rest probability matrices: one row per cell, one column per class (zero-initialised).
P_cal = np.zeros((len(X_cal_all_df), K), dtype=float)
P_te  = np.zeros((len(X_te_all_df),  K), dtype=float)

test_index = Triana_data_Test_Sub.index

# ============================= TRAIN PER-CLASS OVR =============================

# One binary (one-vs-rest) head per class. Each iteration:
#   1. reads the class-specific training/testing barcode lists from CSV,
#   2. builds binary labels (1 = barcode listed under "Positive"),
#   3. fits a StandardScaler on the head's TRAIN slice and applies it to every other matrix,
#   4. trains the four MLTraining base learners and stacks them with logistic regression,
#   5. optionally Platt-calibrates the stacker on the CAL split,
#   6. plots calibration, saves the model bundle, logs binary metrics,
#   7. writes the head's positive-class probabilities into column k of P_cal / P_te.
for celltype in class_names:
    k = class_to_idx[celltype]
    name = str(celltype).replace(" ", "_")
    print(f"\nProcessing {name} (class {k+1}/{K})...")

    # ---- TRAIN slice via barcode lists
    # NOTE(review): the annotation folder is hard-coded in the path below rather than derived
    # from name_target_class — confirm it matches the depth this cell is meant to train.
    train_barcodes_df = pd.read_csv(
        f"{train_barcodes_path}/Triana/Consensus_annotation_broad_final/Barcodes_training_class_{name}.csv",
        index_col=0
    )
    train_positive_barcodes = train_barcodes_df["Positive"].dropna().values
    train_negative_barcodes = train_barcodes_df["Negative"].dropna().values
    all_train_barcodes = np.concatenate([train_positive_barcodes, train_negative_barcodes])

    # Keep only listed barcodes that exist in the TRAIN frame; labels follow the frame's row order.
    train_mask = Triana_data_Train_Sub.index.isin(all_train_barcodes)
    X_tr_df = Triana_data_Train_Sub.loc[train_mask]
    found_train_barcodes = X_tr_df.index.values
    y_tr = np.isin(found_train_barcodes, train_positive_barcodes).astype(int)

    # ---- Skip guards
    # NOTE(review): skipping leaves column k of P_cal / P_te at its preallocated zeros, and the
    # multiclass calibration after the loop still consumes that column.
    if X_tr_df.empty or np.unique(y_tr).size < 2:
        print(f"[SKIP] {name}: empty or single-class train slice (pos={y_tr.sum()}, neg={(len(y_tr)-y_tr.sum())}).")
        continue

    # ---- TEST slice via barcode lists
    test_barcodes_df = pd.read_csv(
        f"{test_barcodes_path}/Triana/Consensus_annotation_broad_final/Barcodes_testing_class_{name}.csv",
        index_col=0
    )
    test_positive_barcodes = test_barcodes_df["Positive"].dropna().values
    test_negative_barcodes = test_barcodes_df["Negative"].dropna().values
    all_test_barcodes = np.concatenate([test_positive_barcodes, test_negative_barcodes])

    # NOTE(review): unlike the TRAIN slice there is no empty-slice guard here; an empty X_te_df
    # would reach scaler.transform below — presumably an error in scikit-learn, confirm.
    test_mask = Triana_data_Test_Sub.index.isin(all_test_barcodes)
    X_te_df = Triana_data_Test_Sub.loc[test_mask]
    found_test_barcodes = X_te_df.index.values
    y_te = np.isin(found_test_barcodes, test_positive_barcodes).astype(int)

    # ---- Full-test & cal for this binary head
    # Binary targets over the complete TEST / CAL splits (1 = row's Celltype equals this class).
    X_te_all_local = X_te_all_df.copy()
    y_te_all = (Triana_data_Test["Celltype"].values == celltype).astype(int)
    X_cal_df = X_cal_all_df.copy()
    y_cal_bin = (Triana_data_Cal_lbl["Celltype"].values == celltype).astype(int)

    # ---- Info
    print(f"Training - Found {X_tr_df.shape[0]} / {len(all_train_barcodes)} barcodes")
    print(f"Training - Pos: {len(train_positive_barcodes)}, Neg: {len(train_negative_barcodes)}")
    print(f"Training labels: {y_tr.sum()} pos, {len(y_tr)-y_tr.sum()} neg")
    print(f"Testing  - Found {X_te_df.shape[0]} / {len(all_test_barcodes)} barcodes")
    print(f"Testing  - Pos: {len(test_positive_barcodes)}, Neg: {len(test_negative_barcodes)}")
    print(f"Testing labels: {y_te.sum()} pos, {len(y_te)-y_te.sum()} neg")
    print(f"Calibrating - Found {X_cal_df.shape[0]} rows | Pos: {y_cal_bin.sum()}, Neg: {len(y_cal_bin)-y_cal_bin.sum()}")
    print(f"All test data: {X_te_all_local.shape[0]} rows, positives for {celltype}: {y_te_all.sum()}")

    # ---- Scaling (fit on per-head TRAIN slice; transform others)
    scaler = StandardScaler(with_mean=True, with_std=True).fit(X_tr_df.values)

    # Applies this head's fitted scaler and re-wraps the result as a DataFrame that keeps the
    # input's row index and the shared feature names (cols_train).
    def _sc(df):
        return pd.DataFrame(
            scaler.transform(df.values),
            index=df.index,
            columns=cols_train,
        )

    X_tr_sc_df     = _sc(X_tr_df)
    X_te_sc_df     = _sc(X_te_df)
    X_te_all_sc_df = _sc(X_te_all_local)
    X_cal_sc_df    = _sc(X_cal_df)

    print(f"[scale] {name}: train mean ~ {X_tr_sc_df.values.mean():.3f}, std ~ {X_tr_sc_df.values.std():.3f}")

    # ---- Base learners
    # Each MLTraining.train_* call receives the scaled TRAIN slice, the shared CV object `kf`
    # and the head name; what it returns is defined in the external MLTraining module.
    NB_model  = MLTraining.train_NB (X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    XGB_model = MLTraining.train_XGB(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    KNN_model = MLTraining.train_KNN(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    MLP_model = MLTraining.train_MLP(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)

    # ---- Stacker (raw)
    # Logistic-regression meta-learner over the base learners' predict_proba outputs.
    stacker_raw = StackingClassifier(
        estimators=[("NB", NB_model), ("XGB", XGB_model), ("KNN", KNN_model), ("MLP", MLP_model)],
        final_estimator=LogisticRegression(max_iter=2000, class_weight="balanced", random_state=42),
        stack_method="predict_proba",
        cv=kf,
        n_jobs=-1,
    ).fit(X_tr_sc_df, y_tr)

    # ---- Feature count asserts (debug safety)
    expected_feats = len(cols_train)
    _assert_feature_counts(name, {
        "NB": NB_model, "XGB": XGB_model, "KNN": KNN_model, "MLP": MLP_model, "Stacker": stacker_raw
    }, expected_feats)

    # ---- Binary calibration (Platt, guarded)
    # Calibration is only attempted when the CAL split contains both positives and negatives.
    pos_cal    = int(y_cal_bin.sum())
    n_cal_bin  = int(len(y_cal_bin))
    has_both   = (0 < pos_cal < n_cal_bin)
    print(f"[CAL] {name}: cal positives={pos_cal}/{n_cal_bin}")

    if has_both:
        # The wrapped-model keyword differs between scikit-learn versions, hence the fallback.
        try:
            calibrator = CalibratedClassifierCV(estimator=stacker_raw, method="sigmoid", cv="prefit")
        except TypeError:  # older sklearn
            calibrator = CalibratedClassifierCV(base_estimator=stacker_raw, method="sigmoid", cv="prefit")
        stacker = calibrator.fit(X_cal_sc_df, y_cal_bin)
    else:
        print(f"[WARN] Skipping calibration for {name}: single-class cal set.")
        stacker = stacker_raw

    # ---- Calibration plot on all-test (optional)
    # Best effort: any failure while plotting is reported and does not stop the loop.
    try:
        y_proba_uncal = stacker_raw.predict_proba(X_te_all_sc_df)[:, 1]
        y_proba_cal   = stacker.predict_proba(X_te_all_sc_df)[:, 1]
        if has_both:
            _ = MLTraining.plot_calibration_curve(
                y_te_all, [y_proba_uncal, y_proba_cal],
                clf_names=["Uncalibrated", "Calibrated"],
                n_bins=15, strategy="quantile",
                title=f"Calibration – {name_target_class}:{name}"
            )
        else:
            _ = MLTraining.plot_calibration_curve(
                y_te_all, [y_proba_uncal],
                clf_names=["Uncalibrated"],
                n_bins=15, strategy="quantile",
                title=f"Calibration (uncal only) – {name_target_class}:{name}"
            )
        plt.tight_layout()
        plt.show()
    except Exception as e:
        print(f"[WARN] Skipped calibration plot for {name}: {e}")

    # ---- Save per-class bundle (model + scaler + columns)
    save_subdir = models_dir / f"{name_target_class}_{name}"
    save_subdir.mkdir(parents=True, exist_ok=True)

    MLTraining.save_models({"Stacked": stacker}, out_dir=save_subdir, tag=f"{name_target_class}_{name}")
    joblib.dump(cols_train, save_subdir / "feature_names.joblib")

    # The bundle groups the model with the scaler and column list used to train it.
    bundle = {
        "atlas": "Triana",
        "depth": name_target_class,
        "label": celltype,
        "model": stacker,          # CalibratedClassifierCV(StackingClassifier) or StackingClassifier
        "columns": cols_train,     # exact panel names, panel order
        "scaler": scaler,          # per-head scaler
        "panel_name": "TotalSeqD_Heme_Oncology_CAT399906",
    }
    bundle_path = save_subdir / f"{name_target_class}_{name}_bundle.joblib"
    joblib.dump(bundle, bundle_path)
    print(f"[SAVE] Wrote bundle with columns+scaler to {bundle_path}")

    # ---- Binary metrics on the class-specific test slice
    # Best effort as well: a failure here only skips this head's entry in metrics_log.
    try:
        m = MLTraining.evaluate_classifier(stacker, X_te_sc_df, y_te, plot_cm=False)
        m.update(celltype=celltype)
        metrics_log.append(m)
        print(f"\n{celltype}\n", m.get("report", ""))
    except Exception as e:
        print(f"[WARN] Binary metrics for {name} skipped: {e}")

    # ---- Store OvR probs for multiclass calibration
    # Column k holds this head's positive-class probability for every CAL / TEST row.
    P_cal[:, k] = stacker.predict_proba(X_cal_sc_df)[:, 1]
    P_te[:,  k] = stacker.predict_proba(X_te_all_sc_df)[:, 1]

# ============================= MULTICLASS CALIBRATION =============================

# Fit one temperature-scaling calibrator on the stacked one-vs-rest probabilities of the CAL
# split, apply it to TEST, then persist the calibrator together with the class order.
print("\nFitting multiclass TemperatureScaling on CAL split...")

# Guards: ensure probs are in [0,1]
if (P_cal < 0).any() or (P_cal > 1).any():
    raise ValueError("P_cal must be probabilities in [0,1].")
if (P_te < 0).any() or (P_te > 1).any():
    raise ValueError("P_te must be probabilities in [0,1].")

ts_cal = TemperatureScaling()
ts_cal.fit(P_cal, y_cal_multiclass)
P_te_mc = ts_cal.transform(P_te)

# Ensure calibrated probs shape (K)
P_te_mc = np.asarray(P_te_mc)
if P_te_mc.ndim == 1:
    P_te_mc = P_te_mc.reshape(-1, 1)
if P_te_mc.shape[1] == 1 and K == 2:
    # Single positive-class column in the binary case: expand to [P(neg), P(pos)].
    P_te_mc = np.hstack([1.0 - P_te_mc, P_te_mc])
elif P_te_mc.shape[1] != K:
    # Report the unexpected shape BEFORE P_te_mc is overwritten; printing afterwards would
    # show the fallback matrix's shape instead of what TemperatureScaling returned.
    print(f"[WARN] TemperatureScaling returned shape {P_te_mc.shape}; fell back to sum-normalized OvR probs.")
    # Fallback: normalize OvR sums (rows summing to zero are left as all-zero)
    row_sums = P_te.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0.0] = 1.0
    P_te_mc = P_te / row_sums

# Persist multiclass temp scaler + class names
joblib.dump(ts_cal, models_dir / f"{name_target_class}_multiclass_temp_scaler.joblib")
(pd.Series(class_names, name="class_name")
   .to_csv(models_dir / f"{name_target_class}_class_names.csv", index=False))

# ============================= PROBS COMPARISON & METRICS =============================

# Side-by-side table of raw one-vs-rest vs temperature-scaled probabilities for every TEST row,
# plus the true label and the argmax prediction (index and name) under each variant.
probs_raw_df = pd.DataFrame(P_te,    index=test_index, columns=[f"raw_{c}" for c in class_names])
probs_mc_df  = pd.DataFrame(P_te_mc, index=test_index, columns=[f"mc_{c}"  for c in class_names])

probs_compare = pd.concat([probs_raw_df, probs_mc_df], axis=1)
probs_compare["true_label"]    = Triana_data_Test["Celltype"].values
probs_compare["pred_raw"]      = P_te.argmax(axis=1)
probs_compare["pred_mc"]       = P_te_mc.argmax(axis=1)
probs_compare["pred_raw_name"] = [class_names[i] for i in probs_compare["pred_raw"].values]
probs_compare["pred_mc_name"]  = [class_names[i] for i in probs_compare["pred_mc"].values]

print("\nPreview of probabilities BEFORE (raw OvR) vs AFTER (multiclass TS):")
print(probs_compare.head(10).to_string())

probs_compare_path = models_dir / f"{name_target_class}_probabilities_before_after_TEST.csv"
probs_compare.to_csv(probs_compare_path, index=True)
print(f"\nSaved probabilities comparison to: {probs_compare_path}")

# Multiclass evaluation
y_pred_mc = P_te_mc.argmax(axis=1)
print("\nMulticlass classification report (TEST):")
# Pass the label set explicitly. Without `labels`, classification_report infers it from
# y_true/y_pred and raises when a class in class_names is neither present in TEST nor ever
# predicted, because the inferred count no longer matches target_names. confusion_matrix
# below already pins the label set the same way. zero_division=0 keeps the default score
# for such empty classes without emitting a warning per class.
print(classification_report(
    y_test_multiclass, y_pred_mc,
    labels=list(range(K)), target_names=class_names, digits=3, zero_division=0,
))

cm = confusion_matrix(y_test_multiclass, y_pred_mc, labels=range(K))
print("Confusion matrix (rows=true, cols=pred):\n", cm)

# Per-class binary head metrics CSV
metrics_df = pd.DataFrame.from_records(metrics_log)
MLTraining.append_metrics_csv(metrics_df, csv_path=Path(models_output) / "stacker_metrics.csv")

print("\nDone.")


#### Simplified annotation

In [None]:
# -*- coding: utf-8 -*-
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.ensemble import StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.calibration import CalibratedClassifierCV
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix

from netcal.scaling import TemperatureScaling
import joblib

# ------------------------------------------------------------------- CONFIG (expects these to already exist)
#   models_output, train_barcodes_path, test_barcodes_path
#   Triana_data_Train, Triana_data_Test, Triana_data_Cal          (DataFrames indexed by barcode)
#   Triana_dataset_Train, Triana_dataset_Test, Triana_dataset_Cal (AnnData with obs labels)
#   TotalSeqD_Heme_Oncology_CAT399906                    (iterable of feature names)
#   MLTraining module with: CV, train_NB, train_XGB, train_KNN, train_MLP,
#                           plot_calibration_curve, save_models, evaluate_classifier, append_metrics_csv

# Annotation depth trained by this cell. It is lower-cased to build the AnnData label field
# and used verbatim in output file names. Other depths used in this notebook include "Broad".
name_target_class = "Simplified"

# Output folders under models_output, created on demand.
fig_root   = Path(models_output, "Figures")
models_dir = Path(models_output, "Models")
fig_root.mkdir(parents=True, exist_ok=True)
models_dir.mkdir(parents=True, exist_ok=True)

kf          = MLTraining.CV   # cross-validation object shared by every learner in this cell
num_cores   = -1              # forwarded to the MLTraining.train_* helpers
metrics_log = list()          # one record per binary head, filled inside the training loop

# ============================= HELPERS =============================

def _norm_feats(names) -> pd.Index:
    """
    Normalizer used ONLY to construct matching keys.
    Panel names remain untouched; data columns are normalized and then mapped BACK
    to the exact panel names via a lookup.
    """
    s = pd.Index(map(str, names))
    s = (s.str.strip()
           .str.lower()
           .str.replace(r"[ _/]+", "-", regex=True)
           .str.replace(r"-+", "-", regex=True)
           .str.strip("-"))
    return s

def attach_celltype(df: pd.DataFrame, ad: "AnnData", field: str) -> pd.DataFrame:
    """
    Return a copy of `df` with a categorical 'Celltype' column taken from `ad.obs[field]`.

    Labels are converted to strings, stripped, and have every whitespace run
    replaced by '_' before being aligned to `df`'s index. Rows of `df` whose
    index is absent from `ad.obs` get a missing label; a warning with their
    count is printed.

    Raises
    ------
    KeyError
        If `field` is not a column of `ad.obs`.
    """
    if field not in ad.obs:
        raise KeyError(f"'{field}' not found in AnnData.obs")

    cleaned = ad.obs[field].astype("string").str.strip()
    cleaned = cleaned.str.replace(r"\s+", "_", regex=True)

    out = df.copy()
    out["Celltype"] = pd.Categorical(cleaned.reindex(out.index))

    n_unlabelled = int(out["Celltype"].isna().sum())
    if n_unlabelled:
        print(f"[WARN] {n_unlabelled} rows got NaN Celltype after reindex; check barcode alignment.")
    return out

def _check_finite(df: pd.DataFrame, tag: str):
    arr = df.to_numpy()
    if not np.isfinite(arr).all():
        bad = np.where(~np.isfinite(arr))
        raise ValueError(f"Non-finite values found in {tag} features at positions {bad}")

def _unwrap_estimator(m):
    return getattr(m, "estimator", None) or getattr(m, "base_estimator", None) or m

def _assert_feature_counts(cell_name: str, models_dict: dict, expected: int):
    """
    Verify that every trained model saw exactly `expected` input features.

    Parameters
    ----------
    cell_name : str
        Name of the binary head, used only in the error message.
    models_dict : dict
        Maps "NB", "XGB", "KNN", "MLP" and "Stacker" to fitted models; missing
        or None entries are skipped.
    expected : int
        Number of feature columns the models should have been fitted on.

    Raises
    ------
    RuntimeError
        If a model (or the object it wraps) exposes `n_features_in_` and the
        value differs from `expected`.
    """
    for name in ("NB", "XGB", "KNN", "MLP", "Stacker"):
        est = models_dict.get(name)
        if est is None:
            continue
        # Inspect the fitted object itself as well as what it wraps. A wrapper may keep an
        # unfitted template in `.estimator` (no n_features_in_), in which case looking only
        # at the unwrapped object would silently skip the check.
        for candidate in (est, _unwrap_estimator(est)):
            nfi = getattr(candidate, "n_features_in_", None)
            if nfi is not None and nfi != expected:
                raise RuntimeError(f"{cell_name}:{name} saw {nfi} features; expected {expected}")

# ============================= LABEL ATTACH =============================

# AnnData.obs column holding the consensus labels for the configured depth.
consensus_field = f"Consensus_annotation_{name_target_class.lower()}_final"

# Add a 'Celltype' column to each split from its matching AnnData object.
Triana_data_Train, Triana_data_Test, Triana_data_Cal = (
    attach_celltype(frame, adata, consensus_field)
    for frame, adata in (
        (Triana_data_Train, Triana_dataset_Train),
        (Triana_data_Test,  Triana_dataset_Test),
        (Triana_data_Cal,   Triana_dataset_Cal),
    )
)

# ============================= PANEL & DATA COLUMN ALIGNMENT =============================

# The panel is the reference vocabulary: names are stringified but never normalized.
panel = pd.Index([str(marker) for marker in TotalSeqD_Heme_Oncology_CAT399906])

# Lookup table: normalized matching key -> exact panel spelling.
panel_keys    = _norm_feats(panel)
norm_to_panel = {key: exact for key, exact in zip(panel_keys, panel)}

# Two panel names sharing one key would overwrite each other in the lookup, so fail fast.
if len(panel) != len(norm_to_panel):
    raise ValueError("Panel contains names that collide after normalization. Consider adjusting _norm_feats rules.")

def _rename_data_to_panel(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of `df` whose feature columns use the exact panel spelling.

    Every column except 'cell_barcode' and 'Celltype' is normalized with
    `_norm_feats` and looked up in `norm_to_panel`. Columns whose key is found
    are renamed to the panel name; columns without a match keep their name.
    When several data columns resolve to the same panel name, the first one
    (in column order) wins and the others are dropped with a warning.
    """
    renamed = df.copy()
    reserved = [c for c in ["cell_barcode", "Celltype"] if c in renamed.columns]
    feat = pd.Index([c for c in renamed.columns if c not in reserved])

    # data column name -> exact panel name, only for columns the panel knows about
    candidates = {}
    for old, key in zip(feat, _norm_feats(feat)):
        target = norm_to_panel.get(key)
        if target is not None:
            candidates[old] = target

    # First claimant of a panel name keeps it; later claimants are scheduled for removal.
    claimed, safe_map, drops = set(), {}, []
    for old, target in candidates.items():
        if target in claimed:
            drops.append(old)
            continue
        claimed.add(target)
        safe_map[old] = target

    if drops:
        print(f"[WARN] Dropping {len(drops)} duplicated-mapped columns (showing up to 5): {drops[:5]}")
        renamed = renamed.drop(columns=drops, errors="ignore")
    renamed = renamed.rename(columns=safe_map)

    print(f"[map] matched {len(safe_map)}/{len(feat)} data columns to panel")
    return renamed

# Bring all three splits onto the panel's exact column names (the panel itself is untouched).
Triana_data_Train, Triana_data_Test, Triana_data_Cal = (
    _rename_data_to_panel(split)
    for split in (Triana_data_Train, Triana_data_Test, Triana_data_Cal)
)

# Intersect each split with the panel IN PANEL ORDER
def _panel_intersection(df: pd.DataFrame) -> pd.DataFrame:
    """
    Restrict `df` to the panel features it contains, followed by its metadata columns.

    Feature order comes from `panel.intersection(..., sort=False)`; the
    'cell_barcode' / 'Celltype' columns (when present) are appended last.

    Raises
    ------
    ValueError
        If no column of `df` matches a panel name.
    """
    meta_cols = [c for c in ["cell_barcode", "Celltype"] if c in df.columns]
    candidate_feats = pd.Index([c for c in df.columns if c not in meta_cols])
    shared = panel.intersection(candidate_feats, sort=False)
    if len(shared) == 0:
        raise ValueError("Panel/Data intersection is empty after renaming. Check mapping rules.")
    ordered_cols = [*shared, *meta_cols]
    return df.reindex(columns=ordered_cols)

# Keep only panel features (plus metadata columns) in every split.
Triana_data_Train, Triana_data_Test, Triana_data_Cal = (
    _panel_intersection(split)
    for split in (Triana_data_Train, Triana_data_Test, Triana_data_Cal)
)

# ============================= FEATURES & LABELS =============================

# Calibration labels live in their own one-column frame, detached from the feature matrix.
Triana_data_Cal_lbl = Triana_data_Cal.loc[:, ["Celltype"]].copy()

# Metadata columns actually present in each split; everything else is a feature.
drop_cols_train = [c for c in ["cell_barcode", "Celltype"] if c in Triana_data_Train.columns]
drop_cols_test  = [c for c in ["cell_barcode", "Celltype"] if c in Triana_data_Test.columns]
drop_cols_cal   = [c for c in ["cell_barcode", "Celltype"] if c in Triana_data_Cal.columns]

# Feature-only frames for the three splits.
Triana_data_Train_Sub = Triana_data_Train.drop(drop_cols_train, axis=1, errors="ignore")
Triana_data_Test_Sub  = Triana_data_Test.drop(drop_cols_test,   axis=1, errors="ignore")
Triana_data_Cal_Sub   = Triana_data_Cal.drop(drop_cols_cal,     axis=1, errors="ignore")

# SAFETY: every split must expose the same features in the same order.
cols_train = list(Triana_data_Train_Sub.columns)
if not all(list(sub.columns) == cols_train for sub in (Triana_data_Test_Sub, Triana_data_Cal_Sub)):
    raise ValueError("Train/Cal/Test feature columns differ after panel intersection!")

# SAFETY: no NaN / inf in any feature matrix.
_check_finite(Triana_data_Train_Sub, "TRAIN")
_check_finite(Triana_data_Test_Sub,  "TEST")
_check_finite(Triana_data_Cal_Sub,   "CAL")

print(f"\n[features] Using {len(cols_train)} panel-intersected features (exact panel names):")
print(cols_train)

# ===== Exclude specific classes from the multiclass set and per-class loop =====
EXCLUDE_CLASSES = {"Macrophage", "ILC", "Stroma"}

# Class order = sorted TRAIN labels minus the exclusions; reused for every probability matrix.
all_classes = sorted(Triana_data_Train["Celltype"].dropna().unique())
class_names = [c for c in all_classes if c not in EXCLUDE_CLASSES]
if len(class_names) == 0:
    raise ValueError("After exclusions, class_names is empty.")
print(f"[classes] Included ({len(class_names)}): {class_names}")
missing = [c for c in all_classes if c in EXCLUDE_CLASSES]
if missing:
    print(f"[classes] Excluded: {missing}")

K            = len(class_names)
class_to_idx = dict(zip(class_names, range(K)))

# --- Multiclass labels (MASKED to included classes) ---
# CAL: boolean mask over all CAL rows, integer targets only for the rows it keeps.
mask_cal_mc = Triana_data_Cal_lbl["Celltype"].isin(class_names)
s_cal = Triana_data_Cal_lbl["Celltype"][mask_cal_mc].map(class_to_idx)
if s_cal.isna().any():
    missing = Triana_data_Cal_lbl.loc[mask_cal_mc & s_cal.isna(), "Celltype"].unique()
    raise ValueError(f"Unknown labels in masked CAL split: {missing}")
y_cal_multiclass = s_cal.to_numpy(dtype=np.int64)

# TEST: same construction.
mask_test_mc = Triana_data_Test["Celltype"].isin(class_names)
s_te = Triana_data_Test["Celltype"][mask_test_mc].map(class_to_idx)
if s_te.isna().any():
    missing = Triana_data_Test.loc[mask_test_mc & s_te.isna(), "Celltype"].unique()
    raise ValueError(f"Unknown labels in masked TEST split: {missing}")
y_test_multiclass = s_te.to_numpy(dtype=np.int64)

# Full CAL / TEST feature frames, shared by every binary head.
X_cal_all_df = Triana_data_Cal_Sub.copy()
X_te_all_df  = Triana_data_Test_Sub.copy()

# One-vs-rest probability matrices (included classes only): all rows, zero-initialised.
P_cal = np.zeros((len(X_cal_all_df), K), dtype=float)
P_te  = np.zeros((len(X_te_all_df),  K), dtype=float)

test_index = Triana_data_Test_Sub.index

# ============================= TRAIN PER-CLASS OVR =============================

# One binary (one-vs-rest) head per INCLUDED class. Each iteration:
#   1. reads the class-specific training/testing barcode lists from CSV,
#   2. builds binary labels (1 = barcode listed under "Positive"),
#   3. fits a StandardScaler on the head's TRAIN slice and applies it to every other matrix,
#   4. trains the four MLTraining base learners and stacks them with logistic regression,
#   5. optionally Platt-calibrates the stacker on the CAL split,
#   6. plots calibration, saves the model bundle, logs binary metrics,
#   7. writes the head's positive-class probabilities into column k of P_cal / P_te.
for celltype in class_names:
    k = class_to_idx[celltype]
    name = str(celltype).replace(" ", "_")
    print(f"\nProcessing {name} (class {k+1}/{K})...")

    # ---- TRAIN slice via barcode lists
    # NOTE(review): the annotation folder is hard-coded in the path below instead of reusing
    # consensus_field / name_target_class; changing the depth in the config will not change it.
    train_barcodes_df = pd.read_csv(
        f"{train_barcodes_path}/Triana/Consensus_annotation_simplified_final/Barcodes_training_class_{name}.csv",
        index_col=0
    )
    train_positive_barcodes = train_barcodes_df["Positive"].dropna().values
    train_negative_barcodes = train_barcodes_df["Negative"].dropna().values
    all_train_barcodes = np.concatenate([train_positive_barcodes, train_negative_barcodes])

    # Keep only listed barcodes that exist in the TRAIN frame; labels follow the frame's row order.
    train_mask = Triana_data_Train_Sub.index.isin(all_train_barcodes)
    X_tr_df = Triana_data_Train_Sub.loc[train_mask]
    found_train_barcodes = X_tr_df.index.values
    y_tr = np.isin(found_train_barcodes, train_positive_barcodes).astype(int)

    # ---- Skip guards
    # NOTE(review): skipping leaves column k of P_cal / P_te at its preallocated zeros, and the
    # multiclass calibration after the loop still consumes that column.
    if X_tr_df.empty or np.unique(y_tr).size < 2:
        print(f"[SKIP] {name}: empty or single-class train slice (pos={y_tr.sum()}, neg={(len(y_tr)-y_tr.sum())}).")
        continue

    # ---- TEST slice via barcode lists
    test_barcodes_df = pd.read_csv(
        f"{test_barcodes_path}/Triana/Consensus_annotation_simplified_final/Barcodes_testing_class_{name}.csv",
        index_col=0
    )
    test_positive_barcodes = test_barcodes_df["Positive"].dropna().values
    test_negative_barcodes = test_barcodes_df["Negative"].dropna().values
    all_test_barcodes = np.concatenate([test_positive_barcodes, test_negative_barcodes])

    # NOTE(review): unlike the TRAIN slice there is no empty-slice guard here; an empty X_te_df
    # would reach scaler.transform below — presumably an error in scikit-learn, confirm.
    test_mask = Triana_data_Test_Sub.index.isin(all_test_barcodes)
    X_te_df = Triana_data_Test_Sub.loc[test_mask]
    found_test_barcodes = X_te_df.index.values
    y_te = np.isin(found_test_barcodes, test_positive_barcodes).astype(int)

    # ---- Full-test & cal for this binary head
    # Binary targets over the complete TEST / CAL splits (1 = row's Celltype equals this class);
    # rows belonging to excluded classes are therefore negatives for every head.
    X_te_all_local = X_te_all_df.copy()
    y_te_all = (Triana_data_Test["Celltype"].values == celltype).astype(int)
    X_cal_df = X_cal_all_df.copy()
    y_cal_bin = (Triana_data_Cal_lbl["Celltype"].values == celltype).astype(int)

    # ---- Info
    print(f"Training - Found {X_tr_df.shape[0]} / {len(all_train_barcodes)} barcodes")
    print(f"Training - Pos: {len(train_positive_barcodes)}, Neg: {len(train_negative_barcodes)}")
    print(f"Training labels: {y_tr.sum()} pos, {len(y_tr)-y_tr.sum()} neg")
    print(f"Testing  - Found {X_te_df.shape[0]} / {len(all_test_barcodes)} barcodes")
    print(f"Testing  - Pos: {len(test_positive_barcodes)}, Neg: {len(test_negative_barcodes)}")
    print(f"Testing  - labels: {y_te.sum()} pos, {len(y_te)-y_te.sum()} neg")
    print(f"Calibrating - Found {X_cal_df.shape[0]} rows | Pos: {y_cal_bin.sum()}, Neg: {len(y_cal_bin)-y_cal_bin.sum()}")
    print(f"All test data: {X_te_all_local.shape[0]} rows, positives for {celltype}: {y_te_all.sum()}")

    # ---- Scaling (fit on per-head TRAIN slice; transform others)
    scaler = StandardScaler(with_mean=True, with_std=True).fit(X_tr_df.values)

    # Applies this head's fitted scaler and re-wraps the result as a DataFrame that keeps the
    # input's row index and the shared feature names (cols_train).
    def _sc(df):
        return pd.DataFrame(
            scaler.transform(df.values),
            index=df.index,
            columns=cols_train,
        )

    X_tr_sc_df     = _sc(X_tr_df)
    X_te_sc_df     = _sc(X_te_df)
    X_te_all_sc_df = _sc(X_te_all_local)
    X_cal_sc_df    = _sc(X_cal_df)

    print(f"[scale] {name}: train mean ~ {X_tr_sc_df.values.mean():.3f}, std ~ {X_tr_sc_df.values.std():.3f}")

    # ---- Base learners
    # Each MLTraining.train_* call receives the scaled TRAIN slice, the shared CV object `kf`
    # and the head name; what it returns is defined in the external MLTraining module.
    NB_model  = MLTraining.train_NB (X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    XGB_model = MLTraining.train_XGB(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    KNN_model = MLTraining.train_KNN(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    MLP_model = MLTraining.train_MLP(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)

    # ---- Stacker (raw)
    # Logistic-regression meta-learner over the base learners' predict_proba outputs.
    stacker_raw = StackingClassifier(
        estimators=[("NB", NB_model), ("XGB", XGB_model), ("KNN", KNN_model), ("MLP", MLP_model)],
        final_estimator=LogisticRegression(max_iter=2000, class_weight="balanced", random_state=42),
        stack_method="predict_proba",
        cv=kf,
        n_jobs=-1,
    ).fit(X_tr_sc_df, y_tr)

    # ---- Feature count asserts (debug safety)
    expected_feats = len(cols_train)
    _assert_feature_counts(name, {
        "NB": NB_model, "XGB": XGB_model, "KNN": KNN_model, "MLP": MLP_model, "Stacker": stacker_raw
    }, expected_feats)

    # ---- Binary calibration (Platt, guarded)
    # Calibration is only attempted when the CAL split contains both positives and negatives.
    pos_cal    = int(y_cal_bin.sum())
    n_cal_bin  = int(len(y_cal_bin))
    has_both   = (0 < pos_cal < n_cal_bin)
    print(f"[CAL] {name}: cal positives={pos_cal}/{n_cal_bin}")

    if has_both:
        # The wrapped-model keyword differs between scikit-learn versions, hence the fallback.
        try:
            calibrator = CalibratedClassifierCV(estimator=stacker_raw, method="sigmoid", cv="prefit")
        except TypeError:  # older sklearn
            calibrator = CalibratedClassifierCV(base_estimator=stacker_raw, method="sigmoid", cv="prefit")
        stacker = calibrator.fit(X_cal_sc_df, y_cal_bin)
    else:
        print(f"[WARN] Skipping calibration for {name}: single-class cal set.")
        stacker = stacker_raw

    # ---- Calibration plot on all-test (optional)
    # Best effort: any failure while plotting is reported and does not stop the loop.
    try:
        y_proba_uncal = stacker_raw.predict_proba(X_te_all_sc_df)[:, 1]
        y_proba_cal   = stacker.predict_proba(X_te_all_sc_df)[:, 1]
        if has_both:
            _ = MLTraining.plot_calibration_curve(
                y_te_all, [y_proba_uncal, y_proba_cal],
                clf_names=["Uncalibrated", "Calibrated"],
                n_bins=15, strategy="quantile",
                title=f"Calibration – {name_target_class}:{name}"
            )
        else:
            _ = MLTraining.plot_calibration_curve(
                y_te_all, [y_proba_uncal],
                clf_names=["Uncalibrated"],
                n_bins=15, strategy="quantile",
                title=f"Calibration (uncal only) – {name_target_class}:{name}"
            )
        plt.tight_layout()
        plt.show()
    except Exception as e:
        print(f"[WARN] Skipped calibration plot for {name}: {e}")

    # ---- Save per-class bundle (model + scaler + columns)
    save_subdir = models_dir / f"{name_target_class}_{name}"
    save_subdir.mkdir(parents=True, exist_ok=True)

    MLTraining.save_models({"Stacked": stacker}, out_dir=save_subdir, tag=f"{name_target_class}_{name}")
    joblib.dump(cols_train, save_subdir / "feature_names.joblib")

    # The bundle groups the model with the scaler and column list used to train it.
    bundle = {
        "atlas": "Triana",
        "depth": name_target_class,
        "label": celltype,
        "model": stacker,          # CalibratedClassifierCV(StackingClassifier) or StackingClassifier
        "columns": cols_train,     # exact panel names, panel order
        "scaler": scaler,          # per-head scaler
        "panel_name": "TotalSeqD_Heme_Oncology_CAT399906",
    }
    bundle_path = save_subdir / f"{name_target_class}_{name}_bundle.joblib"
    joblib.dump(bundle, bundle_path)
    print(f"[SAVE] Wrote bundle with columns+scaler to {bundle_path}")

    # ---- Binary metrics on the class-specific test slice
    # Best effort as well: a failure here only skips this head's entry in metrics_log.
    try:
        m = MLTraining.evaluate_classifier(stacker, X_te_sc_df, y_te, plot_cm=False)
        m.update(celltype=celltype)
        metrics_log.append(m)
        print(f"\n{celltype}\n", m.get("report", ""))
    except Exception as e:
        print(f"[WARN] Binary metrics for {name} skipped: {e}")

    # ---- Store OvR probs for multiclass calibration (columns order = class_names)
    # Probabilities are stored for ALL CAL / TEST rows; masking to included classes happens later.
    P_cal[:, k] = stacker.predict_proba(X_cal_sc_df)[:, 1]
    P_te[:,  k] = stacker.predict_proba(X_te_all_sc_df)[:, 1]

# ============================= MULTICLASS CALIBRATION =============================

# Fit one temperature-scaling calibrator on the stacked one-vs-rest probabilities of the CAL
# rows that belong to an included class, apply it to every TEST row, then persist the
# calibrator together with the included class order.
print("\nFitting multiclass TemperatureScaling on CAL split (excluded classes masked out)...")

# Guards: ensure probs are in [0,1]
if (P_cal < 0).any() or (P_cal > 1).any():
    raise ValueError("P_cal must be probabilities in [0,1].")
if (P_te < 0).any() or (P_te > 1).any():
    raise ValueError("P_te must be probabilities in [0,1].")

ts_cal = TemperatureScaling()
# Fit only on CAL rows whose true label is one of the included classes
ts_cal.fit(P_cal[mask_cal_mc.values, :], y_cal_multiclass)
P_te_mc = ts_cal.transform(P_te)

# Ensure calibrated probs shape (K)
P_te_mc = np.asarray(P_te_mc)
if P_te_mc.ndim == 1:
    P_te_mc = P_te_mc.reshape(-1, 1)
if P_te_mc.shape[1] == 1 and K == 2:
    # Single positive-class column in the binary case: expand to [P(neg), P(pos)].
    P_te_mc = np.hstack([1.0 - P_te_mc, P_te_mc])
elif P_te_mc.shape[1] != K:
    # Report the unexpected shape BEFORE P_te_mc is overwritten; printing afterwards would
    # show the fallback matrix's shape instead of what TemperatureScaling returned.
    print(f"[WARN] TemperatureScaling returned shape {P_te_mc.shape}; fell back to sum-normalized OvR probs.")
    # Fallback: normalize OvR sums (rows summing to zero are left as all-zero)
    row_sums = P_te.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0.0] = 1.0
    P_te_mc = P_te / row_sums

# Persist multiclass temp scaler + INCLUDED class names
joblib.dump(ts_cal, models_dir / f"{name_target_class}_multiclass_temp_scaler.joblib")
(pd.Series(class_names, name="class_name")
   .to_csv(models_dir / f"{name_target_class}_class_names.csv", index=False))

# ============================= PROBS COMPARISON & METRICS =============================

# Evaluate & save on TEST rows whose true label is an INCLUDED class
test_index_masked = Triana_data_Test_Sub.index[mask_test_mc.values]

# Side-by-side table of raw one-vs-rest vs temperature-scaled probabilities, plus the true
# label and the argmax prediction (index and name) under each variant.
probs_raw_df = pd.DataFrame(P_te[mask_test_mc.values, :],    index=test_index_masked,
                            columns=[f"raw_{c}" for c in class_names])
probs_mc_df  = pd.DataFrame(P_te_mc[mask_test_mc.values, :], index=test_index_masked,
                            columns=[f"mc_{c}"  for c in class_names])

probs_compare = pd.concat([probs_raw_df, probs_mc_df], axis=1)
probs_compare["true_label"]    = Triana_data_Test.loc[mask_test_mc, "Celltype"].values
probs_compare["pred_raw"]      = P_te[mask_test_mc.values, :].argmax(axis=1)
probs_compare["pred_mc"]       = P_te_mc[mask_test_mc.values, :].argmax(axis=1)
probs_compare["pred_raw_name"] = [class_names[i] for i in probs_compare["pred_raw"].values]
probs_compare["pred_mc_name"]  = [class_names[i] for i in probs_compare["pred_mc"].values]

print("\nPreview of probabilities BEFORE (raw OvR) vs AFTER (multiclass TS) [included classes only]:")
print(probs_compare.head(10).to_string())

probs_compare_path = models_dir / f"{name_target_class}_probabilities_before_after_TEST_included.csv"
probs_compare.to_csv(probs_compare_path, index=True)
print(f"\nSaved probabilities comparison to: {probs_compare_path}")

# Multiclass evaluation on the masked subset
y_pred_mc = P_te_mc[mask_test_mc.values, :].argmax(axis=1)
print("\nMulticlass classification report (TEST, excluded classes removed):")
# Pass the label set explicitly. Without `labels`, classification_report infers it from
# y_true/y_pred and raises when a class in class_names is neither present in the masked TEST
# rows nor ever predicted, because the inferred count no longer matches target_names.
# confusion_matrix below already pins the label set the same way. zero_division=0 keeps the
# default score for such empty classes without emitting a warning per class.
print(classification_report(
    y_test_multiclass, y_pred_mc,
    labels=list(range(K)), target_names=class_names, digits=3, zero_division=0,
))

cm = confusion_matrix(y_test_multiclass, y_pred_mc, labels=range(K))
print("Confusion matrix (rows=true, cols=pred):\n", cm)

# Per-class binary head metrics CSV
metrics_df = pd.DataFrame.from_records(metrics_log)
MLTraining.append_metrics_csv(metrics_df, csv_path=Path(models_output) / "stacker_metrics.csv")

print("\nDone.")


In [None]:
from pathlib import Path
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
from matplotlib import rc_context

# --------- 1) Collect training barcodes (union across classes or a single class) ---------
def collect_training_barcodes(train_barcodes_path: str | Path,
                              atlas: str = "Triana",
                              depth: str = "simplified",
                              class_name: str | None = None) -> pd.Index:
    """
    Reads CSVs like:
      {train_barcodes_path}/{atlas}/Consensus_annotation_{depth}_final/Barcodes_training_class_{Class}.csv
    and returns the union of 'Positive' and 'Negative' barcodes.
    """
    base = Path(train_barcodes_path) / atlas / f"Consensus_annotation_{depth.lower()}_final"

    if class_name is None:
        files = sorted(base.glob("Barcodes_training_class_*.csv"))
        if not files:
            raise FileNotFoundError(f"No training CSVs found in: {base}")
    else:
        files = [base / f"Barcodes_training_class_{str(class_name).replace(' ', '_')}.csv"]
        if not files[0].exists():
            raise FileNotFoundError(f"Training CSV not found: {files[0]}")

    barcodes = []
    for fp in files:
        df = pd.read_csv(fp, index_col=0)
        for col in ("Positive", "Negative"):
            if col in df.columns:
                barcodes.extend(df[col].dropna().astype(str).tolist())

    return pd.Index(pd.unique(barcodes))


# --------- 2) Plot UMAP for the subset of training barcodes ---------
def plot_training_subset_umap(
    adata,                                 # Triana_dataset_Train (AnnData)
    train_barcodes_path: str | Path,
    *,
    atlas: str = "Triana",
    depth: str = "simplified",             # matches your CSV folder name
    class_name: str | None = None,         # None -> all classes; or "B Memory", etc.
    label_key: str = "Consensus_annotation_detailed_final",
    basis_key: str = "X_mofaumap",
    custom_palette: dict | None = None,    # your dict {label: "#hex"}
    point_size: float = 10.0,
    title: str | None = None,
):
    """
    Plot an embedding of the cells whose barcodes were used for training.

    The barcodes come from ``collect_training_barcodes`` (one class, or the
    union over all classes when ``class_name`` is None). ``adata`` is subset to
    those barcodes (on a copy; the input is not modified) and drawn with
    ``sc.pl.embedding`` coloured by ``label_key``.

    Parameters
    ----------
    adata : AnnData whose ``obs_names`` are barcodes.
    train_barcodes_path, atlas, depth, class_name :
        forwarded to ``collect_training_barcodes``.
    label_key : ``adata.obs`` column used for colouring (cast to category if needed).
    basis_key : embedding passed as ``basis`` to ``sc.pl.embedding``.
    custom_palette : optional ``{label: colour}`` dict; labels missing from it
        are drawn in light grey (#cccccc). None lets scanpy choose colours.
    point_size : marker size passed to scanpy.
    title : plot title; a default is built from ``atlas`` and ``class_name``.

    Raises
    ------
    KeyError
        If ``label_key`` is not a column of ``adata.obs``.
    ValueError
        If none of the training barcodes are present in ``adata.obs_names``.
    FileNotFoundError
        Propagated from ``collect_training_barcodes``.
    """
    # Fail early with a clear message instead of after subsetting/copying.
    if label_key not in adata.obs:
        raise KeyError(f"'{label_key}' not found in adata.obs")

    # 1) get training barcodes
    keep_barcodes = collect_training_barcodes(train_barcodes_path, atlas=atlas, depth=depth, class_name=class_name)

    # 2) subset AnnData
    n_before = adata.n_obs
    mask = adata.obs_names.isin(keep_barcodes)
    n_after = int(mask.sum())
    if n_after == 0:
        raise ValueError("No training barcodes overlapped with adata.obs_names.")
    ad_sub = adata[mask].copy()
    print(f"[subset] kept {n_after}/{n_before} cells for plotting "
          f"({'ALL classes' if class_name is None else f'class={class_name}'})")

    # 3) ensure categorical for coloring
    #    (isinstance check replaces pd.api.types.is_categorical_dtype, deprecated in pandas 2.1)
    if not isinstance(ad_sub.obs[label_key].dtype, pd.CategoricalDtype):
        ad_sub.obs[label_key] = ad_sub.obs[label_key].astype("category")
    cats = list(ad_sub.obs[label_key].cat.categories)

    # 4) palette handling (use your custom palette where available; fallback to grey)
    fallback = "#cccccc"
    if custom_palette is None:
        palette_list = None  # let scanpy pick
    else:
        labels_in_data = set(cats)
        missing_in_palette = [c for c in cats if c not in custom_palette]
        extra_in_palette   = [c for c in custom_palette if c not in labels_in_data]
        if missing_in_palette:
            print("[WARN] Missing colors for:", missing_in_palette, "-> using light grey (#cccccc).")
        if extra_in_palette:
            print("[INFO] Palette has unused entries:", extra_in_palette)

        # One colour per category, in category order.
        palette_list = [custom_palette.get(c, fallback) for c in cats]
        # stash into .uns so scanpy reuses it
        ad_sub.uns[f"{label_key}_colors"] = palette_list

    # 5) plot
    subset_desc = "all classes" if class_name is None else class_name
    with rc_context({"figure.figsize": (5.75, 4.75)}):
        sc.pl.embedding(
            ad_sub,
            basis=basis_key,
            color=label_key,
            palette=palette_list,
            legend_loc="on data",
            legend_fontsize=10,
            legend_fontoutline=1.5,
            size=point_size,
            add_outline=True,
            frameon=True,
            # Default title follows the atlas argument instead of a hard-coded name.
            title=title or f"{atlas} — training subset ({subset_desc})",
            show=True,
        )


# ---------------------- EXAMPLES ----------------------
# 1) Plot ALL training barcodes (union of every class) with your detailed labels + palette
# plot_training_subset_umap(
#     Triana_dataset_Train,
#     train_barcodes_path=train_barcodes_path,
#     atlas="Triana",
#     depth="simplified",
#     class_name=None,
#     label_key="Consensus_annotation_detailed_final",
#     basis_key="X_mofaumap",
#     custom_palette=custom_palette,  # pass the dict you defined
#     point_size=10,
#     title="Triana — all training cells",
# )

# 2) Plot only one class's training barcodes (here: "Erythroid").
#    Keep `title` in sync with `class_name`, otherwise the figure is mislabelled.
plot_training_subset_umap(
    Triana_dataset_Train,
    train_barcodes_path=train_barcodes_path,
    atlas="Triana",
    depth="simplified",
    class_name="Erythroid",
    label_key="Consensus_annotation_detailed_final",
    basis_key="X_mofaumap",
    custom_palette=custom_palette,
    point_size=10,
    title="Triana — training cells: Erythroid",
)


In [None]:
# -*- coding: utf-8 -*-
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.ensemble import StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.calibration import CalibratedClassifierCV
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix

from netcal.scaling import TemperatureScaling
import joblib

# ------------------------------------------------------------------- CONFIG (expects these to already exist)
#   models_output, train_barcodes_path, test_barcodes_path
#   Triana_data_Train, Triana_data_Test, Triana_data_Cal          (DataFrames indexed by barcode)
#   Triana_dataset_Train, Triana_dataset_Test, Triana_dataset_Cal (AnnData with obs labels)
#   TotalSeqD_Heme_Oncology_CAT399906                    (iterable of feature names)
#   MLTraining module with: CV, train_NB, train_XGB, train_KNN, train_MLP,
#                           plot_calibration_curve, save_models, evaluate_classifier, append_metrics_csv

# Annotation depth for this run. It is lower-cased further down to build the
# AnnData.obs field name and used verbatim as a prefix for saved model files.
name_target_class = "Simplified"   # "simplified" | "Simplified" | "Detailed"
# Output folders under models_output; created up front (idempotent via exist_ok).
fig_root   = Path(models_output) / "Figures"
models_dir = Path(models_output) / "Models"
fig_root.mkdir(parents=True, exist_ok=True)
models_dir.mkdir(parents=True, exist_ok=True)

# NOTE: `kf` and `num_cores` rebind the notebook-level values set in the imports cell.
# kf is passed as `cv` to the MLTraining.train_* helpers and to the StackingClassifier.
# num_cores = -1 presumably means "use all cores" (joblib convention) — confirm in MLTraining.
# metrics_log collects one metrics record per class inside the training loop.
kf         = MLTraining.CV
num_cores  = -1
metrics_log = []

# ============================= HELPERS =============================

def _norm_feats(names) -> pd.Index:
    """
    Normalizer used ONLY to construct matching keys.
    Panel names remain untouched; data columns are normalized and then mapped BACK
    to the exact panel names via a lookup.
    """
    s = pd.Index(map(str, names))
    s = (s.str.strip()
           .str.lower()
           .str.replace(r"[ _/]+", "-", regex=True)
           .str.replace(r"-+", "-", regex=True)
           .str.strip("-"))
    return s

def attach_celltype(df: pd.DataFrame, ad: "AnnData", field: str) -> pd.DataFrame:
    """
    Return a copy of ``df`` with a categorical "Celltype" column taken from ``ad.obs[field]``.

    Labels are stripped and every whitespace run is replaced by "_", then
    aligned to ``df.index``. Rows whose index is absent from ``ad.obs`` get NaN
    and a warning is printed with their count.

    Raises
    ------
    KeyError
        If ``field`` is not present in ``ad.obs``.
    """
    if field not in ad.obs:
        raise KeyError(f"'{field}' not found in AnnData.obs")

    labels = ad.obs[field].astype("string").str.strip()
    labels = labels.str.replace(r"\s+", "_", regex=True)

    out = df.copy()
    out["Celltype"] = pd.Categorical(labels.reindex(out.index))

    n_missing = int(out["Celltype"].isna().sum())
    if n_missing:
        print(f"[WARN] {n_missing} rows got NaN Celltype after reindex; check barcode alignment.")
    return out

def _check_finite(df: pd.DataFrame, tag: str):
    arr = df.to_numpy()
    if not np.isfinite(arr).all():
        bad = np.where(~np.isfinite(arr))
        raise ValueError(f"Non-finite values found in {tag} features at positions {bad}")

def _unwrap_estimator(m):
    return getattr(m, "estimator", None) or getattr(m, "base_estimator", None) or m

def _assert_feature_counts(cell_name: str, models_dict: dict, expected: int):
    """
    Check that every supplied model was fitted on ``expected`` features.

    Looks up the keys "NB", "XGB", "KNN", "MLP" and "Stacker" in ``models_dict``;
    missing/None entries and estimators without ``n_features_in_`` are ignored.

    Raises
    ------
    RuntimeError
        On the first model whose ``n_features_in_`` differs from ``expected``.
    """
    for name in ("NB", "XGB", "KNN", "MLP", "Stacker"):
        est = models_dict.get(name)
        if est is None:
            continue
        # Inspect the wrapped estimator when the model is a wrapper.
        nfi = getattr(_unwrap_estimator(est), "n_features_in_", None)
        if nfi is None or nfi == expected:
            continue
        raise RuntimeError(f"{cell_name}:{name} saw {nfi} features; expected {expected}")

# ============================= LABEL ATTACH =============================

# obs column holding the consensus labels for the chosen depth,
# e.g. "Consensus_annotation_simplified_final" when name_target_class is "Simplified".
consensus_field = f"Consensus_annotation_{name_target_class.lower()}_final"

# Add/overwrite the "Celltype" column on each split. The split names are rebound
# to the copies returned by attach_celltype.
Triana_data_Train = attach_celltype(Triana_data_Train, Triana_dataset_Train, consensus_field)
Triana_data_Test  = attach_celltype(Triana_data_Test,  Triana_dataset_Test,  consensus_field)
Triana_data_Cal   = attach_celltype(Triana_data_Cal,   Triana_dataset_Cal,   consensus_field)

# ============================= PANEL & DATA COLUMN ALIGNMENT =============================

# Keep the panel EXACTLY as provided
panel = pd.Index(map(str, TotalSeqD_Heme_Oncology_CAT399906))

# Build a mapping: normalized_key -> exact panel name
panel_keys    = _norm_feats(panel)
norm_to_panel = dict(zip(panel_keys, panel))
# dict() keeps one entry per key, so a shorter dict means at least two panel
# entries share a normalized key (this also triggers if the panel itself
# contains an exact duplicate name).
if len(norm_to_panel) != len(panel):
    raise ValueError("Panel contains names that collide after normalization. Consider adjusting _norm_feats rules.")

def _rename_data_to_panel(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of ``df`` whose feature columns carry the exact panel names.

    A feature column is renamed when its normalized key (see ``_norm_feats``)
    is found in ``norm_to_panel``. 'cell_barcode' and 'Celltype' are never
    touched, and columns without a panel match keep their name. When several
    data columns map to the same panel name, the first is kept and the others
    are dropped (with a warning).
    """
    meta_cols = [c for c in ["cell_barcode", "Celltype"] if c in df.columns]
    feat = pd.Index([c for c in df.columns if c not in meta_cols])

    # data column -> exact panel name, only for columns present in the panel
    rename_map = {}
    for old, key in zip(feat, _norm_feats(feat)):
        target = norm_to_panel.get(key)
        if target is not None:
            rename_map[old] = target

    # Handle duplicate mappings (two data columns → same panel col). Keep first, drop the rest.
    safe_map, drops = {}, []
    claimed = set()
    for old, target in rename_map.items():
        if target in claimed:
            drops.append(old)
            continue
        claimed.add(target)
        safe_map[old] = target

    if drops:
        print(f"[WARN] Dropping {len(drops)} duplicated-mapped columns (showing up to 5): {drops[:5]}")

    renamed = df.drop(columns=drops, errors="ignore").rename(columns=safe_map)

    print(f"[map] matched {len(safe_map)}/{len(feat)} data columns to panel")
    return renamed

# Apply: normalize/rename ONLY data splits (panel remains untouched).
# Each split name is rebound to the renamed copy returned by the helper.
Triana_data_Train = _rename_data_to_panel(Triana_data_Train)
Triana_data_Test  = _rename_data_to_panel(Triana_data_Test)
Triana_data_Cal   = _rename_data_to_panel(Triana_data_Cal)

# Intersect each split with the panel IN PANEL ORDER
def _panel_intersection(df: pd.DataFrame) -> pd.DataFrame:
    """
    Keep only the feature columns that belong to ``panel``, followed by the
    'cell_barcode' / 'Celltype' columns when present.

    Raises
    ------
    ValueError
        If no feature column of ``df`` is part of the panel.
    """
    meta_cols = [c for c in ("cell_barcode", "Celltype") if c in df.columns]
    candidates = pd.Index([c for c in df.columns if c not in meta_cols])
    shared = panel.intersection(candidates, sort=False)
    if len(shared) == 0:
        raise ValueError("Panel/Data intersection is empty after renaming. Check mapping rules.")
    return df.reindex(columns=[*shared, *meta_cols])

# Trim every split to its panel features (plus 'cell_barcode'/'Celltype' if present).
# The split names are rebound, so columns outside the panel are gone from here on.
Triana_data_Train = _panel_intersection(Triana_data_Train)
Triana_data_Test  = _panel_intersection(Triana_data_Test)
Triana_data_Cal   = _panel_intersection(Triana_data_Cal)

# ============================= FEATURES & LABELS =============================

# Labels of the calibration split, kept separately from its feature matrix.
Triana_data_Cal_lbl = Triana_data_Cal[["Celltype"]].copy()

# Non-feature columns to strip from each split (only those actually present).
drop_cols_train = [c for c in ["cell_barcode", "Celltype"] if c in Triana_data_Train.columns]
drop_cols_test  = [c for c in ["cell_barcode", "Celltype"] if c in Triana_data_Test.columns]
drop_cols_cal   = [c for c in ["cell_barcode", "Celltype"] if c in Triana_data_Cal.columns]

# *_Sub frames hold features only; the un-suffixed frames keep "Celltype".
Triana_data_Train_Sub = Triana_data_Train.drop(columns=drop_cols_train, errors="ignore")
Triana_data_Test_Sub  = Triana_data_Test.drop(columns=drop_cols_test,  errors="ignore")
Triana_data_Cal_Sub   = Triana_data_Cal.drop(columns=drop_cols_cal,    errors="ignore")

# SAFETY: shared columns & finiteness checks
# The comparison is on ordered lists, so both the set and the order of columns must match.
cols_train = list(Triana_data_Train_Sub.columns)
if list(Triana_data_Test_Sub.columns) != cols_train or list(Triana_data_Cal_Sub.columns) != cols_train:
    raise ValueError("Train/Cal/Test feature columns differ after panel intersection!")

_check_finite(Triana_data_Train_Sub, "TRAIN")
_check_finite(Triana_data_Test_Sub,  "TEST")
_check_finite(Triana_data_Cal_Sub,   "CAL")

print(f"\n[features] Using {len(cols_train)} panel-intersected features (exact panel names):")
print(cols_train)

# Consistent class order
# Classes are the sorted non-null TRAIN labels; index k in class_names is
# column k of the probability matrices allocated below.
class_names  = sorted(pd.Series(Triana_data_Train["Celltype"]).dropna().unique())
K            = len(class_names)
class_to_idx = {c: i for i, c in enumerate(class_names)}

# Multiclass labels arrays
# Any CAL/TEST label that is not a TRAIN class maps to NaN and aborts here
# (a NaN Celltype presumably ends up in the same branch — verify if that matters).
s_cal = Triana_data_Cal_lbl["Celltype"].map(class_to_idx)
if s_cal.isna().any():
    missing = Triana_data_Cal_lbl.loc[s_cal.isna(), "Celltype"].unique()
    raise ValueError(f"Unknown labels in CAL split: {missing}")
y_cal_multiclass = s_cal.to_numpy(dtype=np.int64)

s_te = Triana_data_Test["Celltype"].map(class_to_idx)
if s_te.isna().any():
    missing = Triana_data_Test.loc[s_te.isna(), "Celltype"].unique()
    raise ValueError(f"Unknown labels in TEST split: {missing}")
y_test_multiclass = s_te.to_numpy(dtype=np.int64)

# Reuse across classes
X_cal_all_df = Triana_data_Cal_Sub.copy()
X_te_all_df  = Triana_data_Test_Sub.copy()

# Preallocate OvR prob mats
# Shape (n_rows, K); column k is filled by the binary head of class_names[k].
P_cal = np.zeros((X_cal_all_df.shape[0], K), dtype=float)
P_te  = np.zeros((X_te_all_df.shape[0],  K), dtype=float)

test_index = Triana_data_Test_Sub.index

# ============================= TRAIN PER-CLASS OVR =============================

# One binary (one-vs-rest) head per class. Each iteration:
#   1. builds the TRAIN/TEST slices from per-class barcode CSVs,
#   2. fits a per-head StandardScaler on the TRAIN slice,
#   3. trains NB/XGB/KNN/MLP base learners and a stacking classifier,
#   4. sigmoid-calibrates the stacker on the CAL split (when CAL has both classes),
#   5. saves the model bundle and logs binary metrics,
#   6. writes the positive-class probabilities into column k of P_cal / P_te.
# NOTE(review): the barcode folder name is hard-coded to
# "Consensus_annotation_simplified_final" and does not follow name_target_class.
for celltype in class_names:
    k = class_to_idx[celltype]
    name = str(celltype).replace(" ", "_")
    print(f"\nProcessing {name} (class {k+1}/{K})...")

    # ---- TRAIN slice via barcode lists
    train_barcodes_df = pd.read_csv(
        f"{train_barcodes_path}/Triana/Consensus_annotation_simplified_final/Barcodes_training_class_{name}.csv",
        index_col=0
    )
    train_positive_barcodes = train_barcodes_df["Positive"].dropna().values
    train_negative_barcodes = train_barcodes_df["Negative"].dropna().values
    all_train_barcodes = np.concatenate([train_positive_barcodes, train_negative_barcodes])

    # Keep only barcodes that exist in the TRAIN frame; a barcode is labelled 1
    # when it appears in the Positive list, otherwise 0.
    train_mask = Triana_data_Train_Sub.index.isin(all_train_barcodes)
    X_tr_df = Triana_data_Train_Sub.loc[train_mask]
    found_train_barcodes = X_tr_df.index.values
    y_tr = np.isin(found_train_barcodes, train_positive_barcodes).astype(int)

    # ---- Skip guards
    # A skipped class leaves column k of P_cal / P_te at the preallocated zeros.
    if X_tr_df.empty or np.unique(y_tr).size < 2:
        print(f"[SKIP] {name}: empty or single-class train slice (pos={y_tr.sum()}, neg={(len(y_tr)-y_tr.sum())}).")
        continue

    # ---- TEST slice via barcode lists
    test_barcodes_df = pd.read_csv(
        f"{test_barcodes_path}/Triana/Consensus_annotation_simplified_final/Barcodes_testing_class_{name}.csv",
        index_col=0
    )
    test_positive_barcodes = test_barcodes_df["Positive"].dropna().values
    test_negative_barcodes = test_barcodes_df["Negative"].dropna().values
    all_test_barcodes = np.concatenate([test_positive_barcodes, test_negative_barcodes])

    test_mask = Triana_data_Test_Sub.index.isin(all_test_barcodes)
    X_te_df = Triana_data_Test_Sub.loc[test_mask]
    found_test_barcodes = X_te_df.index.values
    y_te = np.isin(found_test_barcodes, test_positive_barcodes).astype(int)

    # ---- Full-test & cal for this binary head
    # Binary targets over the WHOLE test / cal splits: 1 where Celltype equals this class.
    X_te_all_local = X_te_all_df.copy()
    y_te_all = (Triana_data_Test["Celltype"].values == celltype).astype(int)
    X_cal_df = X_cal_all_df.copy()
    y_cal_bin = (Triana_data_Cal_lbl["Celltype"].values == celltype).astype(int)

    # ---- Info
    print(f"Training - Found {X_tr_df.shape[0]} / {len(all_train_barcodes)} barcodes")
    print(f"Training - Pos: {len(train_positive_barcodes)}, Neg: {len(train_negative_barcodes)}")
    print(f"Training labels: {y_tr.sum()} pos, {len(y_tr)-y_tr.sum()} neg")
    print(f"Testing  - Found {X_te_df.shape[0]} / {len(all_test_barcodes)} barcodes")
    print(f"Testing  - Pos: {len(test_positive_barcodes)}, Neg: {len(test_negative_barcodes)}")
    print(f"Testing labels: {y_te.sum()} pos, {len(y_te)-y_te.sum()} neg")
    print(f"Calibrating - Found {X_cal_df.shape[0]} rows | Pos: {y_cal_bin.sum()}, Neg: {len(y_cal_bin)-y_cal_bin.sum()}")
    print(f"All test data: {X_te_all_local.shape[0]} rows, positives for {celltype}: {y_te_all.sum()}")

    # ---- Scaling (fit on per-head TRAIN slice; transform others)
    scaler = StandardScaler(with_mean=True, with_std=True).fit(X_tr_df.values)

    def _sc(df):
        # Apply this head's scaler and restore the index / feature names.
        return pd.DataFrame(
            scaler.transform(df.values),
            index=df.index,
            columns=cols_train,
        )

    X_tr_sc_df     = _sc(X_tr_df)
    X_te_sc_df     = _sc(X_te_df)
    X_te_all_sc_df = _sc(X_te_all_local)
    X_cal_sc_df    = _sc(X_cal_df)

    print(f"[scale] {name}: train mean ~ {X_tr_sc_df.values.mean():.3f}, std ~ {X_tr_sc_df.values.std():.3f}")

    # ---- Base learners
    NB_model  = MLTraining.train_NB (X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    XGB_model = MLTraining.train_XGB(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    KNN_model = MLTraining.train_KNN(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    MLP_model = MLTraining.train_MLP(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)

    # ---- Stacker (raw)
    # Logistic-regression meta-learner over the base learners' predict_proba outputs.
    stacker_raw = StackingClassifier(
        estimators=[("NB", NB_model), ("XGB", XGB_model), ("KNN", KNN_model), ("MLP", MLP_model)],
        final_estimator=LogisticRegression(max_iter=2000, class_weight="balanced", random_state=42),
        stack_method="predict_proba",
        cv=kf,
        n_jobs=-1,
    ).fit(X_tr_sc_df, y_tr)

    # ---- Feature count asserts (debug safety)
    expected_feats = len(cols_train)
    _assert_feature_counts(name, {
        "NB": NB_model, "XGB": XGB_model, "KNN": KNN_model, "MLP": MLP_model, "Stacker": stacker_raw
    }, expected_feats)

    # ---- Binary calibration (Platt, guarded)
    # Calibration is only attempted when the CAL split has both positives and negatives.
    pos_cal    = int(y_cal_bin.sum())
    n_cal_bin  = int(len(y_cal_bin))
    has_both   = (0 < pos_cal < n_cal_bin)
    print(f"[CAL] {name}: cal positives={pos_cal}/{n_cal_bin}")

    if has_both:
        # stacker_raw is already fitted, hence cv="prefit"; the keyword for the
        # wrapped model differs between sklearn versions.
        try:
            calibrator = CalibratedClassifierCV(estimator=stacker_raw, method="sigmoid", cv="prefit")
        except TypeError:  # older sklearn
            calibrator = CalibratedClassifierCV(base_estimator=stacker_raw, method="sigmoid", cv="prefit")
        stacker = calibrator.fit(X_cal_sc_df, y_cal_bin)
    else:
        print(f"[WARN] Skipping calibration for {name}: single-class cal set.")
        stacker = stacker_raw

    # ---- Calibration plot on all-test (optional)
    # Best effort: any plotting failure is reported and the loop continues.
    try:
        y_proba_uncal = stacker_raw.predict_proba(X_te_all_sc_df)[:, 1]
        y_proba_cal   = stacker.predict_proba(X_te_all_sc_df)[:, 1]
        if has_both:
            _ = MLTraining.plot_calibration_curve(
                y_te_all, [y_proba_uncal, y_proba_cal],
                clf_names=["Uncalibrated", "Calibrated"],
                n_bins=15, strategy="quantile",
                title=f"Calibration – {name_target_class}:{name}"
            )
        else:
            _ = MLTraining.plot_calibration_curve(
                y_te_all, [y_proba_uncal],
                clf_names=["Uncalibrated"],
                n_bins=15, strategy="quantile",
                title=f"Calibration (uncal only) – {name_target_class}:{name}"
            )
        plt.tight_layout()
        plt.show()
    except Exception as e:
        print(f"[WARN] Skipped calibration plot for {name}: {e}")

    # ---- Save per-class bundle (model + scaler + columns)
    save_subdir = models_dir / f"{name_target_class}_{name}"
    save_subdir.mkdir(parents=True, exist_ok=True)

    MLTraining.save_models({"Stacked": stacker}, out_dir=save_subdir, tag=f"{name_target_class}_{name}")
    joblib.dump(cols_train, save_subdir / "feature_names.joblib")

    # Everything needed to score new data with this head: the model, the
    # per-head scaler and the ordered feature names.
    bundle = {
        "atlas": "Triana",
        "depth": name_target_class,
        "label": celltype,
        "model": stacker,          # CalibratedClassifierCV(StackingClassifier) or StackingClassifier
        "columns": cols_train,     # exact panel names, panel order
        "scaler": scaler,          # per-head scaler
        "panel_name": "TotalSeqD_Heme_Oncology_CAT399906",
    }
    bundle_path = save_subdir / f"{name_target_class}_{name}_bundle.joblib"
    joblib.dump(bundle, bundle_path)
    print(f"[SAVE] Wrote bundle with columns+scaler to {bundle_path}")

    # ---- Binary metrics on the class-specific test slice
    # Best effort: a failure here only skips the metrics record for this class.
    try:
        m = MLTraining.evaluate_classifier(stacker, X_te_sc_df, y_te, plot_cm=False)
        m.update(celltype=celltype)
        metrics_log.append(m)
        print(f"\n{celltype}\n", m.get("report", ""))
    except Exception as e:
        print(f"[WARN] Binary metrics for {name} skipped: {e}")

    # ---- Store OvR probs for multiclass calibration
    # Column 1 of predict_proba is the positive-class probability of this head.
    P_cal[:, k] = stacker.predict_proba(X_cal_sc_df)[:, 1]
    P_te[:,  k] = stacker.predict_proba(X_te_all_sc_df)[:, 1]

# ============================= MULTICLASS CALIBRATION =============================

print("\nFitting multiclass TemperatureScaling on CAL split...")

# Guards: ensure probs are in [0,1]
if (P_cal < 0).any() or (P_cal > 1).any():
    raise ValueError("P_cal must be probabilities in [0,1].")
if (P_te < 0).any() or (P_te > 1).any():
    raise ValueError("P_te must be probabilities in [0,1].")

# Fit on the CAL one-vs-rest probabilities, then transform the TEST probabilities.
ts_cal = TemperatureScaling()
ts_cal.fit(P_cal, y_cal_multiclass)
P_te_mc = ts_cal.transform(P_te)

# Ensure calibrated probs shape (K)
P_te_mc = np.asarray(P_te_mc)
if P_te_mc.ndim == 1:
    P_te_mc = P_te_mc.reshape(-1, 1)
if P_te_mc.shape[1] == 1 and K == 2:
    # Single column in the binary case: expand to [P(class 0), P(class 1)].
    P_te_mc = np.hstack([1.0 - P_te_mc, P_te_mc])
elif P_te_mc.shape[1] != K:
    # Report the unexpected shape BEFORE P_te_mc is replaced, so the message
    # shows what TemperatureScaling actually returned.
    print(f"[WARN] TemperatureScaling returned shape {P_te_mc.shape}; fell back to sum-normalized OvR probs.")
    # Fallback: normalize OvR sums
    row_sums = P_te.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0.0] = 1.0  # avoid 0/0 for rows where every head gave 0
    P_te_mc = P_te / row_sums

# Persist multiclass temp scaler + class names
joblib.dump(ts_cal, models_dir / f"{name_target_class}_multiclass_temp_scaler.joblib")
(pd.Series(class_names, name="class_name")
   .to_csv(models_dir / f"{name_target_class}_class_names.csv", index=False))

# ============================= PROBS COMPARISON & METRICS =============================

# Raw one-vs-rest and multiclass-calibrated probabilities, one column per class.
probs_raw_df = pd.DataFrame(P_te,    index=test_index, columns=[f"raw_{c}" for c in class_names])
probs_mc_df  = pd.DataFrame(P_te_mc, index=test_index, columns=[f"mc_{c}"  for c in class_names])

# One table per test row: both probability sets, the true label and the
# argmax prediction (index and name) of each variant.
probs_compare = pd.concat([probs_raw_df, probs_mc_df], axis=1)
probs_compare["true_label"]    = Triana_data_Test["Celltype"].values
probs_compare["pred_raw"]      = P_te.argmax(axis=1)
probs_compare["pred_mc"]       = P_te_mc.argmax(axis=1)
probs_compare["pred_raw_name"] = [class_names[i] for i in probs_compare["pred_raw"].values]
probs_compare["pred_mc_name"]  = [class_names[i] for i in probs_compare["pred_mc"].values]

print("\nPreview of probabilities BEFORE (raw OvR) vs AFTER (multiclass TS):")
print(probs_compare.head(10).to_string())

probs_compare_path = models_dir / f"{name_target_class}_probabilities_before_after_TEST.csv"
probs_compare.to_csv(probs_compare_path, index=True)
print(f"\nSaved probabilities comparison to: {probs_compare_path}")

# Multiclass evaluation
y_pred_mc = P_te_mc.argmax(axis=1)
print("\nMulticlass classification report (TEST):")
# Pass labels explicitly: without them classification_report infers the label set
# from y_true/y_pred and raises ValueError when a class is absent from both
# (e.g. a class skipped during training that has no test cells), because the
# inferred set no longer matches len(target_names).
print(classification_report(y_test_multiclass, y_pred_mc, labels=list(range(K)),
                            target_names=class_names, digits=3))

cm = confusion_matrix(y_test_multiclass, y_pred_mc, labels=range(K))
print("Confusion matrix (rows=true, cols=pred):\n", cm)

# Per-class binary head metrics CSV
metrics_df = pd.DataFrame.from_records(metrics_log)
MLTraining.append_metrics_csv(metrics_df, csv_path=Path(models_output) / "stacker_metrics.csv")

print("\nDone.")


#### Detailed annotation

In [None]:
# -*- coding: utf-8 -*-
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.ensemble import StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.calibration import CalibratedClassifierCV
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix

from netcal.scaling import TemperatureScaling
import joblib

# ------------------------------------------------------------------- CONFIG (expects these to already exist)
#   models_output, train_barcodes_path, test_barcodes_path
#   Triana_data_Train, Triana_data_Test, Triana_data_Cal          (DataFrames indexed by barcode)
#   Triana_dataset_Train, Triana_dataset_Test, Triana_dataset_Cal (AnnData with obs labels)
#   TotalSeqD_Heme_Oncology_CAT399906                    (iterable of feature names)
#   MLTraining module with: CV, train_NB, train_XGB, train_KNN, train_MLP,
#                           plot_calibration_curve, save_models, evaluate_classifier, append_metrics_csv

# Annotation depth for this run; it is lower-cased further down to build the
# AnnData.obs field name ("Consensus_annotation_detailed_final").
name_target_class = "Detailed"   # "Simplified" | "Detailed"
# Output folders under models_output; created up front (idempotent via exist_ok).
fig_root   = Path(models_output) / "Figures"
models_dir = Path(models_output) / "Models"
fig_root.mkdir(parents=True, exist_ok=True)
models_dir.mkdir(parents=True, exist_ok=True)

# NOTE: `kf` and `num_cores` rebind the notebook-level values set in the imports cell.
# num_cores = -1 presumably means "use all cores" (joblib convention) — confirm in MLTraining.
# metrics_log starts empty for this run.
kf         = MLTraining.CV
num_cores  = -1
metrics_log = []

# ============================= HELPERS =============================

def _norm_feats(names) -> pd.Index:
    """
    Normalizer used ONLY to construct matching keys.
    Panel names remain untouched; data columns are normalized and then mapped BACK
    to the exact panel names via a lookup.
    """
    s = pd.Index(map(str, names))
    s = (s.str.strip()
           .str.lower()
           .str.replace(r"[ _/]+", "-", regex=True)
           .str.replace(r"-+", "-", regex=True)
           .str.strip("-"))
    return s

def attach_celltype(df: pd.DataFrame, ad: "AnnData", field: str) -> pd.DataFrame:
    """
    Return a copy of ``df`` with a categorical "Celltype" column taken from ``ad.obs[field]``.

    Labels are stripped and every whitespace run is replaced by "_", then
    aligned to ``df.index``. Rows whose index is absent from ``ad.obs`` get NaN
    and a warning is printed with their count.

    Raises
    ------
    KeyError
        If ``field`` is not present in ``ad.obs``.
    """
    if field not in ad.obs:
        raise KeyError(f"'{field}' not found in AnnData.obs")

    labels = ad.obs[field].astype("string").str.strip()
    labels = labels.str.replace(r"\s+", "_", regex=True)

    out = df.copy()
    out["Celltype"] = pd.Categorical(labels.reindex(out.index))

    n_missing = int(out["Celltype"].isna().sum())
    if n_missing:
        print(f"[WARN] {n_missing} rows got NaN Celltype after reindex; check barcode alignment.")
    return out

def _check_finite(df: pd.DataFrame, tag: str):
    arr = df.to_numpy()
    if not np.isfinite(arr).all():
        bad = np.where(~np.isfinite(arr))
        raise ValueError(f"Non-finite values found in {tag} features at positions {bad}")

def _unwrap_estimator(m):
    return getattr(m, "estimator", None) or getattr(m, "base_estimator", None) or m

def _assert_feature_counts(cell_name: str, models_dict: dict, expected: int):
    """
    Check that every supplied model was fitted on ``expected`` features.

    Looks up the keys "NB", "XGB", "KNN", "MLP" and "Stacker" in ``models_dict``;
    missing/None entries and estimators without ``n_features_in_`` are ignored.

    Raises
    ------
    RuntimeError
        On the first model whose ``n_features_in_`` differs from ``expected``.
    """
    for name in ("NB", "XGB", "KNN", "MLP", "Stacker"):
        est = models_dict.get(name)
        if est is None:
            continue
        # Inspect the wrapped estimator when the model is a wrapper.
        nfi = getattr(_unwrap_estimator(est), "n_features_in_", None)
        if nfi is None or nfi == expected:
            continue
        raise RuntimeError(f"{cell_name}:{name} saw {nfi} features; expected {expected}")

# ============================= LABEL ATTACH =============================

# obs column holding the consensus labels for the chosen depth,
# e.g. "Consensus_annotation_detailed_final" when name_target_class is "Detailed".
consensus_field = f"Consensus_annotation_{name_target_class.lower()}_final"

# Add/overwrite the "Celltype" column on each split. The split names are rebound
# to the copies returned by attach_celltype.
Triana_data_Train = attach_celltype(Triana_data_Train, Triana_dataset_Train, consensus_field)
Triana_data_Test  = attach_celltype(Triana_data_Test,  Triana_dataset_Test,  consensus_field)
Triana_data_Cal   = attach_celltype(Triana_data_Cal,   Triana_dataset_Cal,   consensus_field)

# ============================= PANEL & DATA COLUMN ALIGNMENT =============================

# Keep the panel EXACTLY as provided
panel = pd.Index(map(str, TotalSeqD_Heme_Oncology_CAT399906))

# Build a mapping: normalized_key -> exact panel name
panel_keys    = _norm_feats(panel)
norm_to_panel = dict(zip(panel_keys, panel))
# dict() keeps one entry per key, so a shorter dict means at least two panel
# entries share a normalized key (this also triggers if the panel itself
# contains an exact duplicate name).
if len(norm_to_panel) != len(panel):
    raise ValueError("Panel contains names that collide after normalization. Consider adjusting _norm_feats rules.")

def _rename_data_to_panel(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of ``df`` whose feature columns carry the exact panel names.

    A feature column is renamed when its normalized key (see ``_norm_feats``)
    is found in ``norm_to_panel``. 'cell_barcode' and 'Celltype' are never
    touched, and columns without a panel match keep their name. When several
    data columns map to the same panel name, the first is kept and the others
    are dropped (with a warning).
    """
    meta_cols = [c for c in ["cell_barcode", "Celltype"] if c in df.columns]
    feat = pd.Index([c for c in df.columns if c not in meta_cols])

    # data column -> exact panel name, only for columns present in the panel
    rename_map = {}
    for old, key in zip(feat, _norm_feats(feat)):
        target = norm_to_panel.get(key)
        if target is not None:
            rename_map[old] = target

    # Handle duplicate mappings (two data columns → same panel col). Keep first, drop the rest.
    safe_map, drops = {}, []
    claimed = set()
    for old, target in rename_map.items():
        if target in claimed:
            drops.append(old)
            continue
        claimed.add(target)
        safe_map[old] = target

    if drops:
        print(f"[WARN] Dropping {len(drops)} duplicated-mapped columns (showing up to 5): {drops[:5]}")

    renamed = df.drop(columns=drops, errors="ignore").rename(columns=safe_map)

    print(f"[map] matched {len(safe_map)}/{len(feat)} data columns to panel")
    return renamed

# Apply: normalize/rename ONLY data splits (panel remains untouched).
# Each split name is rebound to the renamed copy returned by the helper.
Triana_data_Train = _rename_data_to_panel(Triana_data_Train)
Triana_data_Test  = _rename_data_to_panel(Triana_data_Test)
Triana_data_Cal   = _rename_data_to_panel(Triana_data_Cal)

# Intersect each split with the panel IN PANEL ORDER
def _panel_intersection(df: pd.DataFrame) -> pd.DataFrame:
    non_feat = [c for c in ["cell_barcode", "Celltype"] if c in df.columns]
    feat_cols = pd.Index([c for c in df.columns if c not in non_feat])
    inter = panel.intersection(feat_cols, sort=False)
    if inter.empty:
        raise ValueError("Panel/Data intersection is empty after renaming. Check mapping rules.")
    return df.reindex(columns=list(inter) + non_feat)

# Restrict every split to the panel features it shares with the panel (plus label/barcode columns).
Triana_data_Train = _panel_intersection(Triana_data_Train)
Triana_data_Test  = _panel_intersection(Triana_data_Test)
Triana_data_Cal   = _panel_intersection(Triana_data_Cal)

# ============================= FEATURES & LABELS =============================

# Keep the CAL labels in their own frame before label columns are dropped from the feature matrices.
Triana_data_Cal_lbl = Triana_data_Cal[["Celltype"]].copy()

drop_cols_train = [c for c in ["cell_barcode", "Celltype"] if c in Triana_data_Train.columns]
drop_cols_test  = [c for c in ["cell_barcode", "Celltype"] if c in Triana_data_Test.columns]
drop_cols_cal   = [c for c in ["cell_barcode", "Celltype"] if c in Triana_data_Cal.columns]

# *_Sub frames hold features only; the un-suffixed frames still carry "Celltype".
Triana_data_Train_Sub = Triana_data_Train.drop(columns=drop_cols_train, errors="ignore")
Triana_data_Test_Sub  = Triana_data_Test.drop(columns=drop_cols_test,  errors="ignore")
Triana_data_Cal_Sub   = Triana_data_Cal.drop(columns=drop_cols_cal,    errors="ignore")

# SAFETY: shared columns & finiteness checks
# Columns are compared as ordered lists, so both membership and order must agree across splits
# (each split was intersected with the panel independently).
cols_train = list(Triana_data_Train_Sub.columns)
if list(Triana_data_Test_Sub.columns) != cols_train or list(Triana_data_Cal_Sub.columns) != cols_train:
    raise ValueError("Train/Cal/Test feature columns differ after panel intersection!")

_check_finite(Triana_data_Train_Sub, "TRAIN")
_check_finite(Triana_data_Test_Sub,  "TEST")
_check_finite(Triana_data_Cal_Sub,   "CAL")

print(f"\n[features] Using {len(cols_train)} panel-intersected features (exact panel names):")
print(cols_train)

# ===== Exclude specific classes from the multiclass set and per-class loop =====
EXCLUDE_CLASSES = {"Macrophage", "ILC", "Stroma", "dnT"}

# Class order is the sorted set of non-null TRAIN labels minus the exclusions;
# the position in class_names is the column index used in P_cal / P_te.
all_classes = sorted(pd.Series(Triana_data_Train["Celltype"]).dropna().unique())
class_names = [c for c in all_classes if c not in EXCLUDE_CLASSES]
if not class_names:
    raise ValueError("After exclusions, class_names is empty.")
print(f"[classes] Included ({len(class_names)}): {class_names}")
# NOTE: `missing` here holds the excluded classes that actually occur in TRAIN;
# the same name is re-bound further down for unknown labels.
if missing := [c for c in all_classes if c in EXCLUDE_CLASSES]:
    print(f"[classes] Excluded: {missing}")

K            = len(class_names)
class_to_idx = {c: i for i, c in enumerate(class_names)}

# --- Multiclass labels (MASKED to included classes) ---
# CAL
mask_cal_mc = Triana_data_Cal_lbl["Celltype"].isin(class_names)
s_cal = Triana_data_Cal_lbl.loc[mask_cal_mc, "Celltype"].map(class_to_idx)
# Defensive: every masked label is a key of class_to_idx, so this branch is not expected to trigger.
if s_cal.isna().any():
    missing = Triana_data_Cal_lbl.loc[mask_cal_mc & s_cal.isna(), "Celltype"].unique()
    raise ValueError(f"Unknown labels in masked CAL split: {missing}")
y_cal_multiclass = s_cal.to_numpy(dtype=np.int64)

# TEST
mask_test_mc = Triana_data_Test["Celltype"].isin(class_names)
s_te = Triana_data_Test.loc[mask_test_mc, "Celltype"].map(class_to_idx)
if s_te.isna().any():
    missing = Triana_data_Test.loc[mask_test_mc & s_te.isna(), "Celltype"].unique()
    raise ValueError(f"Unknown labels in masked TEST split: {missing}")
y_test_multiclass = s_te.to_numpy(dtype=np.int64)

# Reuse across classes
X_cal_all_df = Triana_data_Cal_Sub.copy()
X_te_all_df  = Triana_data_Test_Sub.copy()

# Preallocate OvR prob mats (only for included classes).
# Rows cover the FULL cal/test splits (masking happens later); columns follow class_names.
# A column stays all-zero if the per-class loop skips that class.
P_cal = np.zeros((X_cal_all_df.shape[0], K), dtype=float)
P_te  = np.zeros((X_te_all_df.shape[0],  K), dtype=float)

test_index = Triana_data_Test_Sub.index

# ============================= TRAIN PER-CLASS OVR =============================

# One one-vs-rest head per included class:
# barcode-defined train/test slices -> per-head scaling -> four base learners ->
# stacked classifier -> optional sigmoid calibration on CAL -> save -> store probabilities in column k.
for celltype in class_names:
    k = class_to_idx[celltype]
    name = str(celltype).replace(" ", "_")
    print(f"\nProcessing {name} (class {k+1}/{K})...")

    # ---- TRAIN slice via barcode lists
    train_barcodes_df = pd.read_csv(
        f"{train_barcodes_path}/Triana/Consensus_annotation_detailed_final/Barcodes_training_class_{name}.csv",
        index_col=0
    )
    train_positive_barcodes = train_barcodes_df["Positive"].dropna().values
    train_negative_barcodes = train_barcodes_df["Negative"].dropna().values
    all_train_barcodes = np.concatenate([train_positive_barcodes, train_negative_barcodes])

    # Only barcodes present in the TRAIN index are used; a row is positive iff its barcode is in the Positive list.
    train_mask = Triana_data_Train_Sub.index.isin(all_train_barcodes)
    X_tr_df = Triana_data_Train_Sub.loc[train_mask]
    found_train_barcodes = X_tr_df.index.values
    y_tr = np.isin(found_train_barcodes, train_positive_barcodes).astype(int)

    # ---- Skip guards
    # NOTE: `continue` leaves column k of P_cal / P_te at its preallocated zeros.
    if X_tr_df.empty or np.unique(y_tr).size < 2:
        print(f"[SKIP] {name}: empty or single-class train slice (pos={y_tr.sum()}, neg={(len(y_tr)-y_tr.sum())}).")
        continue

    # ---- TEST slice via barcode lists
    test_barcodes_df = pd.read_csv(
        f"{test_barcodes_path}/Triana/Consensus_annotation_detailed_final/Barcodes_testing_class_{name}.csv",
        index_col=0
    )
    test_positive_barcodes = test_barcodes_df["Positive"].dropna().values
    test_negative_barcodes = test_barcodes_df["Negative"].dropna().values
    all_test_barcodes = np.concatenate([test_positive_barcodes, test_negative_barcodes])

    test_mask = Triana_data_Test_Sub.index.isin(all_test_barcodes)
    X_te_df = Triana_data_Test_Sub.loc[test_mask]
    found_test_barcodes = X_te_df.index.values
    y_te = np.isin(found_test_barcodes, test_positive_barcodes).astype(int)

    # ---- Full-test & cal for this binary head
    # Binary targets over the FULL test / cal splits (not just the barcode-list slice).
    X_te_all_local = X_te_all_df.copy()
    y_te_all = (Triana_data_Test["Celltype"].values == celltype).astype(int)
    X_cal_df = X_cal_all_df.copy()
    y_cal_bin = (Triana_data_Cal_lbl["Celltype"].values == celltype).astype(int)

    # ---- Info
    print(f"Training - Found {X_tr_df.shape[0]} / {len(all_train_barcodes)} barcodes")
    print(f"Training - Pos: {len(train_positive_barcodes)}, Neg: {len(train_negative_barcodes)}")
    print(f"Training labels: {y_tr.sum()} pos, {len(y_tr)-y_tr.sum()} neg")
    print(f"Testing  - Found {X_te_df.shape[0]} / {len(all_test_barcodes)} barcodes")
    print(f"Testing  - Pos: {len(test_positive_barcodes)}, Neg: {len(test_negative_barcodes)}")
    print(f"Testing  - labels: {y_te.sum()} pos, {len(y_te)-y_te.sum()} neg")
    print(f"Calibrating - Found {X_cal_df.shape[0]} rows | Pos: {y_cal_bin.sum()}, Neg: {len(y_cal_bin)-y_cal_bin.sum()}")
    print(f"All test data: {X_te_all_local.shape[0]} rows, positives for {celltype}: {y_te_all.sum()}")

    # ---- Scaling (fit on per-head TRAIN slice; transform others)
    scaler = StandardScaler(with_mean=True, with_std=True).fit(X_tr_df.values)

    # Closure over this iteration's scaler; re-wraps the scaled array with the
    # original row index and the shared feature names.
    def _sc(df):
        return pd.DataFrame(
            scaler.transform(df.values),
            index=df.index,
            columns=cols_train,
        )

    X_tr_sc_df     = _sc(X_tr_df)
    X_te_sc_df     = _sc(X_te_df)
    X_te_all_sc_df = _sc(X_te_all_local)
    X_cal_sc_df    = _sc(X_cal_df)

    print(f"[scale] {name}: train mean ~ {X_tr_sc_df.values.mean():.3f}, std ~ {X_tr_sc_df.values.std():.3f}")

    # ---- Base learners
    NB_model  = MLTraining.train_NB (X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    XGB_model = MLTraining.train_XGB(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    KNN_model = MLTraining.train_KNN(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    MLP_model = MLTraining.train_MLP(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)

    # ---- Stacker (raw)
    stacker_raw = StackingClassifier(
        estimators=[("NB", NB_model), ("XGB", XGB_model), ("KNN", KNN_model), ("MLP", MLP_model)],
        final_estimator=LogisticRegression(max_iter=2000, class_weight="balanced", random_state=42),
        stack_method="predict_proba",
        cv=kf,
        n_jobs=-1,
    ).fit(X_tr_sc_df, y_tr)

    # ---- Feature count asserts (debug safety)
    expected_feats = len(cols_train)
    _assert_feature_counts(name, {
        "NB": NB_model, "XGB": XGB_model, "KNN": KNN_model, "MLP": MLP_model, "Stacker": stacker_raw
    }, expected_feats)

    # ---- Binary calibration (Platt, guarded)
    # Calibration needs both classes in the CAL split; otherwise the raw stacker is used as-is.
    pos_cal    = int(y_cal_bin.sum())
    n_cal_bin  = int(len(y_cal_bin))
    has_both   = (0 < pos_cal < n_cal_bin)
    print(f"[CAL] {name}: cal positives={pos_cal}/{n_cal_bin}")

    if has_both:
        # Try the newer keyword first and fall back to the older spelling on TypeError.
        try:
            calibrator = CalibratedClassifierCV(estimator=stacker_raw, method="sigmoid", cv="prefit")
        except TypeError:  # older sklearn
            calibrator = CalibratedClassifierCV(base_estimator=stacker_raw, method="sigmoid", cv="prefit")
        stacker = calibrator.fit(X_cal_sc_df, y_cal_bin)
    else:
        print(f"[WARN] Skipping calibration for {name}: single-class cal set.")
        stacker = stacker_raw

    # ---- Calibration plot on all-test (optional)
    # Best-effort: any failure is reported and the loop carries on.
    try:
        y_proba_uncal = stacker_raw.predict_proba(X_te_all_sc_df)[:, 1]
        y_proba_cal   = stacker.predict_proba(X_te_all_sc_df)[:, 1]
        if has_both:
            _ = MLTraining.plot_calibration_curve(
                y_te_all, [y_proba_uncal, y_proba_cal],
                clf_names=["Uncalibrated", "Calibrated"],
                n_bins=15, strategy="quantile",
                title=f"Calibration – {name_target_class}:{name}"
            )
        else:
            _ = MLTraining.plot_calibration_curve(
                y_te_all, [y_proba_uncal],
                clf_names=["Uncalibrated"],
                n_bins=15, strategy="quantile",
                title=f"Calibration (uncal only) – {name_target_class}:{name}"
            )
        plt.tight_layout()
        plt.show()
    except Exception as e:
        print(f"[WARN] Skipped calibration plot for {name}: {e}")

    # ---- Save per-class bundle (model + scaler + columns)
    save_subdir = models_dir / f"{name_target_class}_{name}"
    save_subdir.mkdir(parents=True, exist_ok=True)

    MLTraining.save_models({"Stacked": stacker}, out_dir=save_subdir, tag=f"{name_target_class}_{name}")
    joblib.dump(cols_train, save_subdir / "feature_names.joblib")

    bundle = {
        "atlas": "Triana",
        "depth": name_target_class,
        "label": celltype,
        "model": stacker,          # CalibratedClassifierCV(StackingClassifier) or StackingClassifier
        "columns": cols_train,     # exact panel names, panel order
        "scaler": scaler,          # per-head scaler
        "panel_name": "TotalSeqD_Heme_Oncology_CAT399906",
    }
    bundle_path = save_subdir / f"{name_target_class}_{name}_bundle.joblib"
    joblib.dump(bundle, bundle_path)
    print(f"[SAVE] Wrote bundle with columns+scaler to {bundle_path}")

    # ---- Binary metrics on the class-specific test slice
    # Best-effort as well: a metrics failure only logs a warning.
    try:
        m = MLTraining.evaluate_classifier(stacker, X_te_sc_df, y_te, plot_cm=False)
        m.update(celltype=celltype)
        metrics_log.append(m)
        print(f"\n{celltype}\n", m.get("report", ""))
    except Exception as e:
        print(f"[WARN] Binary metrics for {name} skipped: {e}")

    # ---- Store OvR probs for multiclass calibration (columns order = class_names)
    P_cal[:, k] = stacker.predict_proba(X_cal_sc_df)[:, 1]
    P_te[:,  k] = stacker.predict_proba(X_te_all_sc_df)[:, 1]

# ============================= MULTICLASS CALIBRATION =============================

print("\nFitting multiclass TemperatureScaling on CAL split (excluded classes masked out)...")

# Guards: ensure probs are in [0,1]
if (P_cal < 0).any() or (P_cal > 1).any():
    raise ValueError("P_cal must be probabilities in [0,1].")
if (P_te < 0).any() or (P_te > 1).any():
    raise ValueError("P_te must be probabilities in [0,1].")

ts_cal = TemperatureScaling()
# Fit only on CAL rows whose true label is one of the included classes
# (y_cal_multiclass was built from the same mask, so rows and labels line up).
ts_cal.fit(P_cal[mask_cal_mc.values, :], y_cal_multiclass)
# The transform is applied to ALL test rows; masking to included classes happens at evaluation time.
P_te_mc = ts_cal.transform(P_te)

# Ensure calibrated probs have shape (n_rows, K).
# Three cases are handled: a 1-D result is turned into a single column; a single
# column with K == 2 is expanded to [1 - p, p]; any other column-count mismatch
# falls back to the raw OvR probabilities normalized by their row sums.
P_te_mc = np.asarray(P_te_mc)
if P_te_mc.ndim == 1:
    P_te_mc = P_te_mc.reshape(-1, 1)
if P_te_mc.shape[1] == 1 and K == 2:
    P_te_mc = np.hstack([1.0 - P_te_mc, P_te_mc])
elif P_te_mc.shape[1] != K:
    # Remember the unexpected shape BEFORE P_te_mc is overwritten, so the warning
    # reports what TemperatureScaling actually returned rather than the fallback's shape.
    _ts_returned_shape = P_te_mc.shape
    # Fallback: normalize OvR sums (rows summing to zero are divided by 1 and stay all-zero)
    row_sums = P_te.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0.0] = 1.0
    P_te_mc = P_te / row_sums
    print(f"[WARN] TemperatureScaling returned shape {_ts_returned_shape}; fell back to sum-normalized OvR probs.")

# Persist multiclass temp scaler + INCLUDED class names
# (the CSV row order is the column order of the probability matrices).
joblib.dump(ts_cal, models_dir / f"{name_target_class}_multiclass_temp_scaler.joblib")
(pd.Series(class_names, name="class_name")
   .to_csv(models_dir / f"{name_target_class}_class_names.csv", index=False))

# ============================= PROBS COMPARISON & METRICS =============================

# Evaluate & save on TEST rows whose true label is an INCLUDED class
test_index_masked = Triana_data_Test_Sub.index[mask_test_mc.values]

probs_raw_df = pd.DataFrame(P_te[mask_test_mc.values, :],    index=test_index_masked,
                            columns=[f"raw_{c}" for c in class_names])
probs_mc_df  = pd.DataFrame(P_te_mc[mask_test_mc.values, :], index=test_index_masked,
                            columns=[f"mc_{c}"  for c in class_names])

# pred_raw / pred_mc are argmax column indices into class_names; the *_name columns spell them out.
probs_compare = pd.concat([probs_raw_df, probs_mc_df], axis=1)
probs_compare["true_label"]    = Triana_data_Test.loc[mask_test_mc, "Celltype"].values
probs_compare["pred_raw"]      = P_te[mask_test_mc.values, :].argmax(axis=1)
probs_compare["pred_mc"]       = P_te_mc[mask_test_mc.values, :].argmax(axis=1)
probs_compare["pred_raw_name"] = [class_names[i] for i in probs_compare["pred_raw"].values]
probs_compare["pred_mc_name"]  = [class_names[i] for i in probs_compare["pred_mc"].values]

print("\nPreview of probabilities BEFORE (raw OvR) vs AFTER (multiclass TS) [included classes only]:")
print(probs_compare.head(10).to_string())

probs_compare_path = models_dir / f"{name_target_class}_probabilities_before_after_TEST_included.csv"
probs_compare.to_csv(probs_compare_path, index=True)
print(f"\nSaved probabilities comparison to: {probs_compare_path}")

# Multiclass evaluation on the masked subset
# (y_test_multiclass was built from the same mask, so it aligns row-for-row with y_pred_mc).
y_pred_mc = P_te_mc[mask_test_mc.values, :].argmax(axis=1)
print("\nMulticlass classification report (TEST, excluded classes removed):")
print(classification_report(y_test_multiclass, y_pred_mc, target_names=class_names, digits=3))

cm = confusion_matrix(y_test_multiclass, y_pred_mc, labels=range(K))
print("Confusion matrix (rows=true, cols=pred):\n", cm)

# Per-class binary head metrics CSV (one record per class that reached the metrics step in the loop)
metrics_df = pd.DataFrame.from_records(metrics_log)
MLTraining.append_metrics_csv(metrics_df, csv_path=Path(models_output) / "stacker_metrics.csv")

print("\nDone.")


## Luecken Models

In [None]:
# Create the output folders for the Luecken atlas (no error if they already exist)
os.makedirs(data_path + "/Luecken", exist_ok=True)
os.makedirs(data_path + "/Luecken/Models", exist_ok=True)

# Root output directory for this atlas; the training cells derive their
# "Models" and "Figures" sub-directories from it.
models_output = data_path + "/Luecken/"

### ML Training

In [None]:
# Empty container for Luecken models; the training cell that follows does not write to it
# (it persists models to disk instead) — TODO confirm whether this is still needed.
Luecken_Models = {}

#### Broad annotation

In [None]:
# -*- coding: utf-8 -*-
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.ensemble import StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.calibration import CalibratedClassifierCV
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix

from netcal.scaling import TemperatureScaling
import joblib

# ------------------------------------------------------------------- CONFIG (expects these to already exist)
#   models_output, train_barcodes_path, test_barcodes_path
#   Luecken_data_Train, Luecken_data_Test, Luecken_data_Cal          (DataFrames indexed by barcode)
#   Luecken_dataset_Train, Luecken_dataset_Test, Luecken_dataset_Cal (AnnData with obs labels)
#   TotalSeqD_Heme_Oncology_CAT399906                    (iterable of feature names)
#   MLTraining module with: CV, train_NB, train_XGB, train_KNN, train_MLP,
#                           plot_calibration_curve, save_models, evaluate_classifier, append_metrics_csv

# Annotation depth trained by this cell; also used to build the consensus obs field name,
# the saved file names, and the "depth" entry of each bundle.
name_target_class = "Broad"   # "Broad" | "Simplified" | "Detailed"
fig_root   = Path(models_output) / "Figures"
models_dir = Path(models_output) / "Models"
fig_root.mkdir(parents=True, exist_ok=True)
models_dir.mkdir(parents=True, exist_ok=True)

# NOTE: these re-bind the notebook-level `kf` (a 5-fold KFold) and `num_cores`
# (cpu_count() - 2) defined in the set-up cell.
kf         = MLTraining.CV
num_cores  = -1   # presumably forwarded as n_jobs (-1 = all cores) — confirm in MLTraining
# Filled by the per-class loop with one metrics dict per class.
metrics_log = []

# ============================= HELPERS =============================

def _norm_feats(names) -> pd.Index:
    """
    Normalizer used ONLY to construct matching keys.
    Panel names remain untouched; data columns are normalized and then mapped BACK
    to the exact panel names via a lookup.
    """
    s = pd.Index(map(str, names))
    s = (s.str.strip()
           .str.lower()
           .str.replace(r"[ _/]+", "-", regex=True)
           .str.replace(r"-+", "-", regex=True)
           .str.strip("-"))
    return s

def attach_celltype(df: pd.DataFrame, ad: "AnnData", field: str) -> pd.DataFrame:
    """
    Return a copy of `df` with a categorical 'Celltype' column taken from `ad.obs[field]`.

    Labels are cast to string, trimmed, and every whitespace run is replaced by "_";
    they are then aligned to `df`'s index. Rows whose index value has no label in
    `ad.obs` get a missing Celltype, and a warning with their count is printed.

    Raises KeyError when `field` is not present in `ad.obs`.
    """
    obs = ad.obs
    if field not in obs:
        raise KeyError(f"'{field}' not found in AnnData.obs")

    cleaned = obs[field].astype("string")
    cleaned = cleaned.str.strip()
    cleaned = cleaned.str.replace(r"\s+", "_", regex=True)

    # Align labels to the rows of df by index value.
    aligned = cleaned.reindex(df.index)

    out = df.copy()
    out["Celltype"] = pd.Categorical(aligned)

    n_missing = int(out["Celltype"].isna().sum())
    if n_missing:
        print(f"[WARN] {n_missing} rows got NaN Celltype after reindex; check barcode alignment.")
    return out

def _check_finite(df: pd.DataFrame, tag: str):
    arr = df.to_numpy()
    if not np.isfinite(arr).all():
        bad = np.where(~np.isfinite(arr))
        raise ValueError(f"Non-finite values found in {tag} features at positions {bad}")

def _unwrap_estimator(m):
    return getattr(m, "estimator", None) or getattr(m, "base_estimator", None) or m

def _assert_feature_counts(cell_name: str, models_dict: dict, expected: int):
    pairs = [
        ("NB",  models_dict.get("NB")),
        ("XGB", models_dict.get("XGB")),
        ("KNN", models_dict.get("KNN")),
        ("MLP", models_dict.get("MLP")),
        ("Stacker", models_dict.get("Stacker")),
    ]
    for name, est in pairs:
        if est is None:
            continue
        base = _unwrap_estimator(est)
        nfi = getattr(base, "n_features_in_", None)
        if nfi is not None and nfi != expected:
            raise RuntimeError(f"{cell_name}:{name} saw {nfi} features; expected {expected}")

# ============================= LABEL ATTACH =============================

# obs column holding the consensus labels for the chosen depth,
# e.g. "Consensus_annotation_broad_final" when name_target_class is "Broad".
consensus_field = f"Consensus_annotation_{name_target_class.lower()}_final"

# Each call returns a copy with a (re)written categorical "Celltype" column;
# labels are matched to rows through the DataFrame index.
Luecken_data_Train = attach_celltype(Luecken_data_Train, Luecken_dataset_Train, consensus_field)
Luecken_data_Test  = attach_celltype(Luecken_data_Test,  Luecken_dataset_Test,  consensus_field)
Luecken_data_Cal   = attach_celltype(Luecken_data_Cal,   Luecken_dataset_Cal,   consensus_field)

# ============================= PANEL & DATA COLUMN ALIGNMENT =============================

# Keep the panel EXACTLY as provided (entries are only cast to str)
panel = pd.Index(map(str, TotalSeqD_Heme_Oncology_CAT399906))

# Build a mapping: normalized_key -> exact panel name
panel_keys    = _norm_feats(panel)
norm_to_panel = dict(zip(panel_keys, panel))
# A dict keeps a single value per key, so a size mismatch means at least two
# panel entries share one normalized key and the reverse lookup would be ambiguous.
if len(norm_to_panel) != len(panel):
    raise ValueError("Panel contains names that collide after normalization. Consider adjusting _norm_feats rules.")

def _rename_data_to_panel(df: pd.DataFrame) -> pd.DataFrame:
    """
    Rename only feature columns so that after normalization they map
    back to the exact panel column names. Keeps 'cell_barcode' and 'Celltype' intact.
    """
    df = df.copy()
    non_feat = [c for c in ["cell_barcode", "Celltype"] if c in df.columns]
    feat     = pd.Index([c for c in df.columns if c not in non_feat])

    feat_keys   = _norm_feats(feat)
    mapped      = [norm_to_panel.get(k) for k in feat_keys]  # None if not in panel
    rename_map  = {old: new for old, new in zip(feat, mapped) if new is not None}

    # Handle duplicate mappings (two data columns → same panel col). Keep first, drop the rest.
    seen, safe_map, drops = set(), {}, []
    for old, new in rename_map.items():
        if new in seen:
            drops.append(old)
        else:
            seen.add(new); safe_map[old] = new

    if drops:
        print(f"[WARN] Dropping {len(drops)} duplicated-mapped columns (showing up to 5): {drops[:5]}")

    if drops:
        df.drop(columns=drops, inplace=True, errors="ignore")
    df.rename(columns=safe_map, inplace=True)

    matched = len(safe_map)
    print(f"[map] matched {matched}/{len(feat)} data columns to panel")
    return df

# Apply: normalize/rename ONLY data splits (panel remains untouched).
# Each call returns a renamed copy, so the split names are re-bound to the new frames.
Luecken_data_Train = _rename_data_to_panel(Luecken_data_Train)
Luecken_data_Test  = _rename_data_to_panel(Luecken_data_Test)
Luecken_data_Cal   = _rename_data_to_panel(Luecken_data_Cal)

# Intersect each split with the panel IN PANEL ORDER
def _panel_intersection(df: pd.DataFrame) -> pd.DataFrame:
    non_feat = [c for c in ["cell_barcode", "Celltype"] if c in df.columns]
    feat_cols = pd.Index([c for c in df.columns if c not in non_feat])
    inter = panel.intersection(feat_cols, sort=False)
    if inter.empty:
        raise ValueError("Panel/Data intersection is empty after renaming. Check mapping rules.")
    return df.reindex(columns=list(inter) + non_feat)

# Restrict every split to the panel features it shares with the panel (plus label/barcode columns).
Luecken_data_Train = _panel_intersection(Luecken_data_Train)
Luecken_data_Test  = _panel_intersection(Luecken_data_Test)
Luecken_data_Cal   = _panel_intersection(Luecken_data_Cal)

# ============================= FEATURES & LABELS =============================

# Keep the CAL labels in their own frame before label columns are dropped from the feature matrices.
Luecken_data_Cal_lbl = Luecken_data_Cal[["Celltype"]].copy()

drop_cols_train = [c for c in ["cell_barcode", "Celltype"] if c in Luecken_data_Train.columns]
drop_cols_test  = [c for c in ["cell_barcode", "Celltype"] if c in Luecken_data_Test.columns]
drop_cols_cal   = [c for c in ["cell_barcode", "Celltype"] if c in Luecken_data_Cal.columns]

# *_Sub frames hold features only; the un-suffixed frames still carry "Celltype".
Luecken_data_Train_Sub = Luecken_data_Train.drop(columns=drop_cols_train, errors="ignore")
Luecken_data_Test_Sub  = Luecken_data_Test.drop(columns=drop_cols_test,  errors="ignore")
Luecken_data_Cal_Sub   = Luecken_data_Cal.drop(columns=drop_cols_cal,    errors="ignore")

# SAFETY: shared columns & finiteness checks
# Columns are compared as ordered lists, so both membership and order must agree across splits
# (each split was intersected with the panel independently).
cols_train = list(Luecken_data_Train_Sub.columns)
if list(Luecken_data_Test_Sub.columns) != cols_train or list(Luecken_data_Cal_Sub.columns) != cols_train:
    raise ValueError("Train/Cal/Test feature columns differ after panel intersection!")

_check_finite(Luecken_data_Train_Sub, "TRAIN")
_check_finite(Luecken_data_Test_Sub,  "TEST")
_check_finite(Luecken_data_Cal_Sub,   "CAL")

print(f"\n[features] Using {len(cols_train)} panel-intersected features (exact panel names):")
print(cols_train)

# Consistent class order: sorted non-null TRAIN labels; the position in class_names
# is the column index used in P_cal / P_te.
class_names  = sorted(pd.Series(Luecken_data_Train["Celltype"]).dropna().unique())
K            = len(class_names)
class_to_idx = {c: i for i, c in enumerate(class_names)}

# Multiclass labels arrays
# A CAL/TEST label that is not a TRAIN class maps to NaN and aborts the cell here;
# a row with a missing Celltype is expected to do the same.
s_cal = Luecken_data_Cal_lbl["Celltype"].map(class_to_idx)
if s_cal.isna().any():
    missing = Luecken_data_Cal_lbl.loc[s_cal.isna(), "Celltype"].unique()
    raise ValueError(f"Unknown labels in CAL split: {missing}")
y_cal_multiclass = s_cal.to_numpy(dtype=np.int64)

s_te = Luecken_data_Test["Celltype"].map(class_to_idx)
if s_te.isna().any():
    missing = Luecken_data_Test.loc[s_te.isna(), "Celltype"].unique()
    raise ValueError(f"Unknown labels in TEST split: {missing}")
y_test_multiclass = s_te.to_numpy(dtype=np.int64)

# Reuse across classes
X_cal_all_df = Luecken_data_Cal_Sub.copy()
X_te_all_df  = Luecken_data_Test_Sub.copy()

# Preallocate OvR prob mats: one row per cal/test cell, one column per class.
# A column stays all-zero if the per-class loop skips that class.
P_cal = np.zeros((X_cal_all_df.shape[0], K), dtype=float)
P_te  = np.zeros((X_te_all_df.shape[0],  K), dtype=float)

test_index = Luecken_data_Test_Sub.index

# ============================= TRAIN PER-CLASS OVR =============================

# One one-vs-rest head per class:
# barcode-defined train/test slices -> per-head scaling -> four base learners ->
# stacked classifier -> optional sigmoid calibration on CAL -> save -> store probabilities in column k.
for celltype in class_names:
    k = class_to_idx[celltype]
    name = str(celltype).replace(" ", "_")
    print(f"\nProcessing {name} (class {k+1}/{K})...")

    # ---- TRAIN slice via barcode lists
    # The annotation directory comes from `consensus_field` (built from name_target_class
    # earlier in this cell) instead of a hard-coded depth, so the barcode lists always
    # match the depth whose labels were attached.
    train_barcodes_df = pd.read_csv(
        f"{train_barcodes_path}/Luecken/{consensus_field}/Barcodes_training_class_{name}.csv",
        index_col=0
    )
    train_positive_barcodes = train_barcodes_df["Positive"].dropna().values
    train_negative_barcodes = train_barcodes_df["Negative"].dropna().values
    all_train_barcodes = np.concatenate([train_positive_barcodes, train_negative_barcodes])

    # Only barcodes present in the TRAIN index are used; a row is positive iff its barcode is in the Positive list.
    train_mask = Luecken_data_Train_Sub.index.isin(all_train_barcodes)
    X_tr_df = Luecken_data_Train_Sub.loc[train_mask]
    found_train_barcodes = X_tr_df.index.values
    y_tr = np.isin(found_train_barcodes, train_positive_barcodes).astype(int)

    # ---- Skip guards
    # NOTE: `continue` leaves column k of P_cal / P_te at its preallocated zeros.
    if X_tr_df.empty or np.unique(y_tr).size < 2:
        print(f"[SKIP] {name}: empty or single-class train slice (pos={y_tr.sum()}, neg={(len(y_tr)-y_tr.sum())}).")
        continue

    # ---- TEST slice via barcode lists
    test_barcodes_df = pd.read_csv(
        f"{test_barcodes_path}/Luecken/{consensus_field}/Barcodes_testing_class_{name}.csv",
        index_col=0
    )
    test_positive_barcodes = test_barcodes_df["Positive"].dropna().values
    test_negative_barcodes = test_barcodes_df["Negative"].dropna().values
    all_test_barcodes = np.concatenate([test_positive_barcodes, test_negative_barcodes])

    test_mask = Luecken_data_Test_Sub.index.isin(all_test_barcodes)
    X_te_df = Luecken_data_Test_Sub.loc[test_mask]
    found_test_barcodes = X_te_df.index.values
    y_te = np.isin(found_test_barcodes, test_positive_barcodes).astype(int)

    # ---- Full-test & cal for this binary head
    # Binary targets over the FULL test / cal splits (not just the barcode-list slice).
    X_te_all_local = X_te_all_df.copy()
    y_te_all = (Luecken_data_Test["Celltype"].values == celltype).astype(int)
    X_cal_df = X_cal_all_df.copy()
    y_cal_bin = (Luecken_data_Cal_lbl["Celltype"].values == celltype).astype(int)

    # ---- Info
    print(f"Training - Found {X_tr_df.shape[0]} / {len(all_train_barcodes)} barcodes")
    print(f"Training - Pos: {len(train_positive_barcodes)}, Neg: {len(train_negative_barcodes)}")
    print(f"Training labels: {y_tr.sum()} pos, {len(y_tr)-y_tr.sum()} neg")
    print(f"Testing  - Found {X_te_df.shape[0]} / {len(all_test_barcodes)} barcodes")
    print(f"Testing  - Pos: {len(test_positive_barcodes)}, Neg: {len(test_negative_barcodes)}")
    print(f"Testing labels: {y_te.sum()} pos, {len(y_te)-y_te.sum()} neg")
    print(f"Calibrating - Found {X_cal_df.shape[0]} rows | Pos: {y_cal_bin.sum()}, Neg: {len(y_cal_bin)-y_cal_bin.sum()}")
    print(f"All test data: {X_te_all_local.shape[0]} rows, positives for {celltype}: {y_te_all.sum()}")

    # ---- Scaling (fit on per-head TRAIN slice; transform others)
    scaler = StandardScaler(with_mean=True, with_std=True).fit(X_tr_df.values)

    # Closure over this iteration's scaler; re-wraps the scaled array with the
    # original row index and the shared feature names.
    def _sc(df):
        return pd.DataFrame(
            scaler.transform(df.values),
            index=df.index,
            columns=cols_train,
        )

    X_tr_sc_df     = _sc(X_tr_df)
    X_te_sc_df     = _sc(X_te_df)
    X_te_all_sc_df = _sc(X_te_all_local)
    X_cal_sc_df    = _sc(X_cal_df)

    print(f"[scale] {name}: train mean ~ {X_tr_sc_df.values.mean():.3f}, std ~ {X_tr_sc_df.values.std():.3f}")

    # ---- Base learners
    NB_model  = MLTraining.train_NB (X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    XGB_model = MLTraining.train_XGB(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    KNN_model = MLTraining.train_KNN(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    MLP_model = MLTraining.train_MLP(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)

    # ---- Stacker (raw)
    stacker_raw = StackingClassifier(
        estimators=[("NB", NB_model), ("XGB", XGB_model), ("KNN", KNN_model), ("MLP", MLP_model)],
        final_estimator=LogisticRegression(max_iter=2000, class_weight="balanced", random_state=42),
        stack_method="predict_proba",
        cv=kf,
        n_jobs=-1,
    ).fit(X_tr_sc_df, y_tr)

    # ---- Feature count asserts (debug safety)
    expected_feats = len(cols_train)
    _assert_feature_counts(name, {
        "NB": NB_model, "XGB": XGB_model, "KNN": KNN_model, "MLP": MLP_model, "Stacker": stacker_raw
    }, expected_feats)

    # ---- Binary calibration (Platt, guarded)
    # Calibration needs both classes in the CAL split; otherwise the raw stacker is used as-is.
    pos_cal    = int(y_cal_bin.sum())
    n_cal_bin  = int(len(y_cal_bin))
    has_both   = (0 < pos_cal < n_cal_bin)
    print(f"[CAL] {name}: cal positives={pos_cal}/{n_cal_bin}")

    if has_both:
        # Try the newer keyword first and fall back to the older spelling on TypeError.
        try:
            calibrator = CalibratedClassifierCV(estimator=stacker_raw, method="sigmoid", cv="prefit")
        except TypeError:  # older sklearn
            calibrator = CalibratedClassifierCV(base_estimator=stacker_raw, method="sigmoid", cv="prefit")
        stacker = calibrator.fit(X_cal_sc_df, y_cal_bin)
    else:
        print(f"[WARN] Skipping calibration for {name}: single-class cal set.")
        stacker = stacker_raw

    # ---- Calibration plot on all-test (optional)
    # Best-effort: any failure is reported and the loop carries on.
    try:
        y_proba_uncal = stacker_raw.predict_proba(X_te_all_sc_df)[:, 1]
        y_proba_cal   = stacker.predict_proba(X_te_all_sc_df)[:, 1]
        if has_both:
            _ = MLTraining.plot_calibration_curve(
                y_te_all, [y_proba_uncal, y_proba_cal],
                clf_names=["Uncalibrated", "Calibrated"],
                n_bins=15, strategy="quantile",
                title=f"Calibration – {name_target_class}:{name}"
            )
        else:
            _ = MLTraining.plot_calibration_curve(
                y_te_all, [y_proba_uncal],
                clf_names=["Uncalibrated"],
                n_bins=15, strategy="quantile",
                title=f"Calibration (uncal only) – {name_target_class}:{name}"
            )
        plt.tight_layout()
        plt.show()
    except Exception as e:
        print(f"[WARN] Skipped calibration plot for {name}: {e}")

    # ---- Save per-class bundle (model + scaler + columns)
    save_subdir = models_dir / f"{name_target_class}_{name}"
    save_subdir.mkdir(parents=True, exist_ok=True)

    MLTraining.save_models({"Stacked": stacker}, out_dir=save_subdir, tag=f"{name_target_class}_{name}")
    joblib.dump(cols_train, save_subdir / "feature_names.joblib")

    bundle = {
        "atlas": "Luecken",
        "depth": name_target_class,
        "label": celltype,
        "model": stacker,          # CalibratedClassifierCV(StackingClassifier) or StackingClassifier
        "columns": cols_train,     # exact panel names, panel order
        "scaler": scaler,          # per-head scaler
        "panel_name": "TotalSeqD_Heme_Oncology_CAT399906",
    }
    bundle_path = save_subdir / f"{name_target_class}_{name}_bundle.joblib"
    joblib.dump(bundle, bundle_path)
    print(f"[SAVE] Wrote bundle with columns+scaler to {bundle_path}")

    # ---- Binary metrics on the class-specific test slice
    # Best-effort as well: a metrics failure only logs a warning.
    try:
        m = MLTraining.evaluate_classifier(stacker, X_te_sc_df, y_te, plot_cm=False)
        m.update(celltype=celltype)
        metrics_log.append(m)
        print(f"\n{celltype}\n", m.get("report", ""))
    except Exception as e:
        print(f"[WARN] Binary metrics for {name} skipped: {e}")

    # ---- Store OvR probs for multiclass calibration
    P_cal[:, k] = stacker.predict_proba(X_cal_sc_df)[:, 1]
    P_te[:,  k] = stacker.predict_proba(X_te_all_sc_df)[:, 1]

# ============================= MULTICLASS CALIBRATION =============================

print("\nFitting multiclass TemperatureScaling on CAL split...")

# Guards: ensure probs are in [0,1]
if (P_cal < 0).any() or (P_cal > 1).any():
    raise ValueError("P_cal must be probabilities in [0,1].")
if (P_te < 0).any() or (P_te > 1).any():
    raise ValueError("P_te must be probabilities in [0,1].")

# Fit on the stacked per-class CAL probabilities against the integer class labels,
# then apply the fitted scaler to the TEST probabilities.
ts_cal = TemperatureScaling()
ts_cal.fit(P_cal, y_cal_multiclass)
P_te_mc = ts_cal.transform(P_te)

# Ensure calibrated probs have shape (n_rows, K).
# Three cases are handled: a 1-D result is turned into a single column; a single
# column with K == 2 is expanded to [1 - p, p]; any other column-count mismatch
# falls back to the raw OvR probabilities normalized by their row sums.
P_te_mc = np.asarray(P_te_mc)
if P_te_mc.ndim == 1:
    P_te_mc = P_te_mc.reshape(-1, 1)
if P_te_mc.shape[1] == 1 and K == 2:
    P_te_mc = np.hstack([1.0 - P_te_mc, P_te_mc])
elif P_te_mc.shape[1] != K:
    # Remember the unexpected shape BEFORE P_te_mc is overwritten, so the warning
    # reports what TemperatureScaling actually returned rather than the fallback's shape.
    _ts_returned_shape = P_te_mc.shape
    # Fallback: normalize OvR sums (rows summing to zero are divided by 1 and stay all-zero)
    row_sums = P_te.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0.0] = 1.0
    P_te_mc = P_te / row_sums
    print(f"[WARN] TemperatureScaling returned shape {_ts_returned_shape}; fell back to sum-normalized OvR probs.")

# Persist multiclass temp scaler + class names
# (the CSV row order is the column order of the probability matrices).
joblib.dump(ts_cal, models_dir / f"{name_target_class}_multiclass_temp_scaler.joblib")
(pd.Series(class_names, name="class_name")
   .to_csv(models_dir / f"{name_target_class}_class_names.csv", index=False))

# ============================= PROBS COMPARISON & METRICS =============================

probs_raw_df = pd.DataFrame(P_te,    index=test_index, columns=[f"raw_{c}" for c in class_names])
probs_mc_df  = pd.DataFrame(P_te_mc, index=test_index, columns=[f"mc_{c}"  for c in class_names])

# pred_raw / pred_mc are argmax column indices into class_names; the *_name columns spell them out.
probs_compare = pd.concat([probs_raw_df, probs_mc_df], axis=1)
probs_compare["true_label"]    = Luecken_data_Test["Celltype"].values
probs_compare["pred_raw"]      = P_te.argmax(axis=1)
probs_compare["pred_mc"]       = P_te_mc.argmax(axis=1)
probs_compare["pred_raw_name"] = [class_names[i] for i in probs_compare["pred_raw"].values]
probs_compare["pred_mc_name"]  = [class_names[i] for i in probs_compare["pred_mc"].values]

print("\nPreview of probabilities BEFORE (raw OvR) vs AFTER (multiclass TS):")
print(probs_compare.head(10).to_string())

probs_compare_path = models_dir / f"{name_target_class}_probabilities_before_after_TEST.csv"
probs_compare.to_csv(probs_compare_path, index=True)
print(f"\nSaved probabilities comparison to: {probs_compare_path}")

# Multiclass evaluation: predicted class = argmax of the temperature-scaled probabilities
y_pred_mc = P_te_mc.argmax(axis=1)
print("\nMulticlass classification report (TEST):")
print(classification_report(y_test_multiclass, y_pred_mc, target_names=class_names, digits=3))

cm = confusion_matrix(y_test_multiclass, y_pred_mc, labels=range(K))
print("Confusion matrix (rows=true, cols=pred):\n", cm)

# Per-class binary head metrics CSV (one record per class that reached the metrics step in the loop)
metrics_df = pd.DataFrame.from_records(metrics_log)
MLTraining.append_metrics_csv(metrics_df, csv_path=Path(models_output) / "stacker_metrics.csv")

print("\nDone.")


#### Simplified annotation

In [None]:
# -*- coding: utf-8 -*-
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.ensemble import StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.calibration import CalibratedClassifierCV
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix

from netcal.scaling import TemperatureScaling
import joblib

# ------------------------------------------------------------------- CONFIG (expects these to already exist)
#   models_output, train_barcodes_path, test_barcodes_path
#   Luecken_data_Train, Luecken_data_Test, Luecken_data_Cal          (DataFrames indexed by barcode)
#   Luecken_dataset_Train, Luecken_dataset_Test, Luecken_dataset_Cal (AnnData with obs labels)
#   TotalSeqD_Heme_Oncology_CAT399906                    (iterable of feature names)
#   MLTraining module with: CV, train_NB, train_XGB, train_KNN, train_MLP,
#                           plot_calibration_curve, save_models, evaluate_classifier, append_metrics_csv

# Annotation depth handled by this cell. The value is used verbatim in output file and
# folder names and lower-cased to build the AnnData.obs label column name.
name_target_class = "Simplified"
# Output folders under `models_output`; both are created if they do not exist yet.
fig_root   = Path(models_output) / "Figures"
models_dir = Path(models_output) / "Models"
fig_root.mkdir(parents=True, exist_ok=True)
models_dir.mkdir(parents=True, exist_ok=True)

# NOTE: `kf` and `num_cores` rebind the notebook-level values created in the set-up
# cell (a 5-fold KFold and cpu_count()-2) for every cell executed afterwards.
kf         = MLTraining.CV   # handed as `cv=` to every base learner and to the stacker
num_cores  = -1              # presumably "use all cores" (joblib convention) — confirm in MLTraining
metrics_log = []             # collects one metrics dict per evaluated class

# ============================= HELPERS =============================

def _norm_feats(names) -> pd.Index:
    """
    Normalizer used ONLY to construct matching keys.
    Panel names remain untouched; data columns are normalized and then mapped BACK
    to the exact panel names via a lookup.
    """
    s = pd.Index(map(str, names))
    s = (s.str.strip()
           .str.lower()
           .str.replace(r"[ _/]+", "-", regex=True)
           .str.replace(r"-+", "-", regex=True)
           .str.strip("-"))
    return s

def attach_celltype(df: pd.DataFrame, ad: "AnnData", field: str) -> pd.DataFrame:
    """
    Return a copy of ``df`` carrying a categorical ``Celltype`` column read from ``ad.obs[field]``.

    Labels are converted to strings, trimmed, and every whitespace run is
    replaced by an underscore. They are then aligned to ``df`` by index label;
    rows of ``df`` without a matching entry in ``ad.obs`` receive a missing
    label and a warning with their count is printed.

    Raises:
        KeyError: if ``field`` is not a column of ``ad.obs``.
    """
    if field not in ad.obs:
        raise KeyError(f"'{field}' not found in AnnData.obs")

    labels = ad.obs[field].astype("string")
    labels = labels.str.strip()
    labels = labels.str.replace(r"\s+", "_", regex=True)

    result = df.copy()
    result["Celltype"] = pd.Categorical(labels.reindex(result.index))

    n_missing = int(result["Celltype"].isna().sum())
    if n_missing > 0:
        print(f"[WARN] {n_missing} rows got NaN Celltype after reindex; check barcode alignment.")
    return result

def _check_finite(df: pd.DataFrame, tag: str):
    arr = df.to_numpy()
    if not np.isfinite(arr).all():
        bad = np.where(~np.isfinite(arr))
        raise ValueError(f"Non-finite values found in {tag} features at positions {bad}")

def _unwrap_estimator(m):
    return getattr(m, "estimator", None) or getattr(m, "base_estimator", None) or m

def _assert_feature_counts(cell_name: str, models_dict: dict, expected: int):
    pairs = [
        ("NB",  models_dict.get("NB")),
        ("XGB", models_dict.get("XGB")),
        ("KNN", models_dict.get("KNN")),
        ("MLP", models_dict.get("MLP")),
        ("Stacker", models_dict.get("Stacker")),
    ]
    for name, est in pairs:
        if est is None:
            continue
        base = _unwrap_estimator(est)
        nfi = getattr(base, "n_features_in_", None)
        if nfi is not None and nfi != expected:
            raise RuntimeError(f"{cell_name}:{name} saw {nfi} features; expected {expected}")

# ============================= LABEL ATTACH =============================

# Name of the AnnData.obs column that holds the consensus labels for this depth.
consensus_field = f"Consensus_annotation_{name_target_class.lower()}_final"

# Each split name is rebound to a labelled copy carrying a categorical "Celltype" column.
Luecken_data_Train = attach_celltype(Luecken_data_Train, Luecken_dataset_Train, consensus_field)
Luecken_data_Test  = attach_celltype(Luecken_data_Test,  Luecken_dataset_Test,  consensus_field)
Luecken_data_Cal   = attach_celltype(Luecken_data_Cal,   Luecken_dataset_Cal,   consensus_field)

# ============================= PANEL & DATA COLUMN ALIGNMENT =============================

# Keep the panel EXACTLY as provided
panel = pd.Index(map(str, TotalSeqD_Heme_Oncology_CAT399906))

# Build a mapping: normalized_key -> exact panel name
panel_keys    = _norm_feats(panel)
norm_to_panel = dict(zip(panel_keys, panel))
# A dict keeps a single entry per key, so a shorter dict means at least two panel
# entries produced the same normalized key (or the panel repeats a name).
if len(norm_to_panel) != len(panel):
    raise ValueError("Panel contains names that collide after normalization. Consider adjusting _norm_feats rules.")

def _rename_data_to_panel(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of ``df`` whose feature columns carry the exact panel names.

    A feature column is renamed when its normalized key (see ``_norm_feats``)
    is present in ``norm_to_panel``. 'cell_barcode' and 'Celltype' are never
    touched, and columns without a panel match keep their name. When several
    data columns resolve to the same panel name, the first one is kept and the
    others are dropped (with a warning). A one-line match summary is printed.
    """
    renamed = df.copy()
    reserved = [c for c in ["cell_barcode", "Celltype"] if c in renamed.columns]
    feature_cols = pd.Index([c for c in renamed.columns if c not in reserved])

    # data column -> exact panel name, for every column whose key the panel knows
    candidates = {}
    for col, key in zip(feature_cols, _norm_feats(feature_cols)):
        target = norm_to_panel.get(key)
        if target is not None:
            candidates[col] = target

    # The first column claiming a panel name wins; later claimants are discarded.
    claimed = set()
    first_claim = {}
    duplicates = []
    for col, target in candidates.items():
        if target in claimed:
            duplicates.append(col)
            continue
        claimed.add(target)
        first_claim[col] = target

    if duplicates:
        print(f"[WARN] Dropping {len(duplicates)} duplicated-mapped columns (showing up to 5): {duplicates[:5]}")
        renamed = renamed.drop(columns=duplicates, errors="ignore")
    renamed = renamed.rename(columns=first_claim)

    print(f"[map] matched {len(first_claim)}/{len(feature_cols)} data columns to panel")
    return renamed

# Apply: normalize/rename ONLY data splits (panel remains untouched).
# Each call returns a renamed copy and prints how many data columns matched the panel.
Luecken_data_Train = _rename_data_to_panel(Luecken_data_Train)
Luecken_data_Test  = _rename_data_to_panel(Luecken_data_Test)
Luecken_data_Cal   = _rename_data_to_panel(Luecken_data_Cal)

# Intersect each split with the panel IN PANEL ORDER
def _panel_intersection(df: pd.DataFrame) -> pd.DataFrame:
    """
    Restrict ``df`` to the panel features it contains, followed by its bookkeeping columns.

    Feature columns are those shared between ``panel`` and ``df`` as returned by
    ``panel.intersection(..., sort=False)`` (intended to follow panel order);
    'cell_barcode' and 'Celltype' are appended afterwards when present.

    Raises:
        ValueError: if ``df`` shares no feature column with the panel.
    """
    reserved = [c for c in ["cell_barcode", "Celltype"] if c in df.columns]
    candidates = pd.Index([c for c in df.columns if c not in reserved])
    shared = panel.intersection(candidates, sort=False)
    if shared.empty:
        raise ValueError("Panel/Data intersection is empty after renaming. Check mapping rules.")
    ordered = list(shared) + reserved
    return df.reindex(columns=ordered)

# Keep only panel features (plus "cell_barcode"/"Celltype" when present) in every split;
# each call raises ValueError if a split shares no feature with the panel.
Luecken_data_Train = _panel_intersection(Luecken_data_Train)
Luecken_data_Test  = _panel_intersection(Luecken_data_Test)
Luecken_data_Cal   = _panel_intersection(Luecken_data_Cal)

# ============================= FEATURES & LABELS =============================

# Keep the calibration labels in their own frame; the feature matrix below drops them.
Luecken_data_Cal_lbl = Luecken_data_Cal[["Celltype"]].copy()

# Non-feature columns to strip from each split (only those actually present).
drop_cols_train = [c for c in ["cell_barcode", "Celltype"] if c in Luecken_data_Train.columns]
drop_cols_test  = [c for c in ["cell_barcode", "Celltype"] if c in Luecken_data_Test.columns]
drop_cols_cal   = [c for c in ["cell_barcode", "Celltype"] if c in Luecken_data_Cal.columns]

# Feature-only frames, still indexed like the labelled splits.
Luecken_data_Train_Sub = Luecken_data_Train.drop(columns=drop_cols_train, errors="ignore")
Luecken_data_Test_Sub  = Luecken_data_Test.drop(columns=drop_cols_test,  errors="ignore")
Luecken_data_Cal_Sub   = Luecken_data_Cal.drop(columns=drop_cols_cal,    errors="ignore")

# SAFETY: shared columns & finiteness checks
cols_train = list(Luecken_data_Train_Sub.columns)
# Names AND order must agree across splits: the training loop passes raw `.values`
# arrays to a scaler fitted on the train slice, so columns are matched by position.
if list(Luecken_data_Test_Sub.columns) != cols_train or list(Luecken_data_Cal_Sub.columns) != cols_train:
    raise ValueError("Train/Cal/Test feature columns differ after panel intersection!")

# Each call raises ValueError when the split contains NaN or +/-inf.
_check_finite(Luecken_data_Train_Sub, "TRAIN")
_check_finite(Luecken_data_Test_Sub,  "TEST")
_check_finite(Luecken_data_Cal_Sub,   "CAL")

print(f"\n[features] Using {len(cols_train)} panel-intersected features (exact panel names):")
print(cols_train)

# ===== Exclude specific classes from the multiclass set and per-class loop =====
EXCLUDE_CLASSES = {"Macrophage", "ILC", "Stroma"}

# Candidate classes come from the TRAIN labels only; sorting fixes the column order
# of the probability matrices built below.
all_classes = sorted(pd.Series(Luecken_data_Train["Celltype"]).dropna().unique())
class_names = [c for c in all_classes if c not in EXCLUDE_CLASSES]
if not class_names:
    raise ValueError("After exclusions, class_names is empty.")
print(f"[classes] Included ({len(class_names)}): {class_names}")
# Here `missing` lists the excluded classes that actually occur in TRAIN; the same
# name is reused further down for unknown labels.
if missing := [c for c in all_classes if c in EXCLUDE_CLASSES]:
    print(f"[classes] Excluded: {missing}")

K            = len(class_names)
class_to_idx = {c: i for i, c in enumerate(class_names)}  # class name -> column index

# --- Multiclass labels (MASKED to included classes) ---
# Rows whose label is missing or not in `class_names` are masked out (isin() is False for them).
# CAL
mask_cal_mc = Luecken_data_Cal_lbl["Celltype"].isin(class_names)
s_cal = Luecken_data_Cal_lbl.loc[mask_cal_mc, "Celltype"].map(class_to_idx)
# Defensive check: after masking every remaining label should have an index.
if s_cal.isna().any():
    missing = Luecken_data_Cal_lbl.loc[mask_cal_mc & s_cal.isna(), "Celltype"].unique()
    raise ValueError(f"Unknown labels in masked CAL split: {missing}")
y_cal_multiclass = s_cal.to_numpy(dtype=np.int64)

# TEST
mask_test_mc = Luecken_data_Test["Celltype"].isin(class_names)
s_te = Luecken_data_Test.loc[mask_test_mc, "Celltype"].map(class_to_idx)
if s_te.isna().any():
    missing = Luecken_data_Test.loc[mask_test_mc & s_te.isna(), "Celltype"].unique()
    raise ValueError(f"Unknown labels in masked TEST split: {missing}")
y_test_multiclass = s_te.to_numpy(dtype=np.int64)

# Reuse across classes
X_cal_all_df = Luecken_data_Cal_Sub.copy()
X_te_all_df  = Luecken_data_Test_Sub.copy()

# Preallocate OvR prob mats (only for included classes)
# One row per CAL/TEST cell (unmasked) and one column per included class, zero-filled.
P_cal = np.zeros((X_cal_all_df.shape[0], K), dtype=float)
P_te  = np.zeros((X_te_all_df.shape[0],  K), dtype=float)

# Full (unmasked) TEST index; the evaluation section builds its own masked index.
test_index = Luecken_data_Test_Sub.index

# ============================= TRAIN PER-CLASS OVR =============================

# One-vs-rest training: for every included class, fit a scaler, four base learners and a
# stacked classifier on that class's barcode-defined TRAIN slice, calibrate the stacker on
# the CAL split, persist the result and store its positive-class probabilities in column k
# of P_cal / P_te.
for celltype in class_names:
    k = class_to_idx[celltype]
    name = str(celltype).replace(" ", "_")
    print(f"\nProcessing {name} (class {k+1}/{K})...")

    # ---- TRAIN slice via barcode lists
    # The folder follows `consensus_field`, so the barcode lists always belong to the same
    # annotation depth as the labels attached earlier in this cell.
    train_barcodes_df = pd.read_csv(
        f"{train_barcodes_path}/Luecken/{consensus_field}/Barcodes_training_class_{name}.csv",
        index_col=0
    )
    train_positive_barcodes = train_barcodes_df["Positive"].dropna().values
    train_negative_barcodes = train_barcodes_df["Negative"].dropna().values
    all_train_barcodes = np.concatenate([train_positive_barcodes, train_negative_barcodes])

    train_mask = Luecken_data_Train_Sub.index.isin(all_train_barcodes)
    X_tr_df = Luecken_data_Train_Sub.loc[train_mask]
    found_train_barcodes = X_tr_df.index.values
    # Binary target: 1 for barcodes listed as "Positive", 0 for the listed negatives.
    y_tr = np.isin(found_train_barcodes, train_positive_barcodes).astype(int)

    # ---- Skip guards
    # A skipped class keeps its all-zero column in P_cal / P_te.
    if X_tr_df.empty or np.unique(y_tr).size < 2:
        print(f"[SKIP] {name}: empty or single-class train slice (pos={y_tr.sum()}, neg={(len(y_tr)-y_tr.sum())}).")
        continue

    # ---- TEST slice via barcode lists
    test_barcodes_df = pd.read_csv(
        f"{test_barcodes_path}/Luecken/{consensus_field}/Barcodes_testing_class_{name}.csv",
        index_col=0
    )
    test_positive_barcodes = test_barcodes_df["Positive"].dropna().values
    test_negative_barcodes = test_barcodes_df["Negative"].dropna().values
    all_test_barcodes = np.concatenate([test_positive_barcodes, test_negative_barcodes])

    test_mask = Luecken_data_Test_Sub.index.isin(all_test_barcodes)
    X_te_df = Luecken_data_Test_Sub.loc[test_mask]
    found_test_barcodes = X_te_df.index.values
    y_te = np.isin(found_test_barcodes, test_positive_barcodes).astype(int)
    # The scaler refuses zero-row input, so remember whether the slice can be used at all.
    has_test_slice = not X_te_df.empty

    # ---- Full-test & cal for this binary head
    X_te_all_local = X_te_all_df.copy()
    y_te_all = (Luecken_data_Test["Celltype"].values == celltype).astype(int)
    X_cal_df = X_cal_all_df.copy()
    y_cal_bin = (Luecken_data_Cal_lbl["Celltype"].values == celltype).astype(int)

    # ---- Info
    print(f"Training - Found {X_tr_df.shape[0]} / {len(all_train_barcodes)} barcodes")
    print(f"Training - Pos: {len(train_positive_barcodes)}, Neg: {len(train_negative_barcodes)}")
    print(f"Training labels: {y_tr.sum()} pos, {len(y_tr)-y_tr.sum()} neg")
    print(f"Testing  - Found {X_te_df.shape[0]} / {len(all_test_barcodes)} barcodes")
    print(f"Testing  - Pos: {len(test_positive_barcodes)}, Neg: {len(test_negative_barcodes)}")
    print(f"Testing  - labels: {y_te.sum()} pos, {len(y_te)-y_te.sum()} neg")
    print(f"Calibrating - Found {X_cal_df.shape[0]} rows | Pos: {y_cal_bin.sum()}, Neg: {len(y_cal_bin)-y_cal_bin.sum()}")
    print(f"All test data: {X_te_all_local.shape[0]} rows, positives for {celltype}: {y_te_all.sum()}")

    # ---- Scaling (fit on per-head TRAIN slice; transform others)
    scaler = StandardScaler(with_mean=True, with_std=True).fit(X_tr_df.values)

    def _sc(df):
        # Columns are matched by position; equality of names/order was checked before the loop.
        return pd.DataFrame(
            scaler.transform(df.values),
            index=df.index,
            columns=cols_train,
        )

    X_tr_sc_df     = _sc(X_tr_df)
    # An empty slice is passed through untouched instead of crashing the whole run.
    X_te_sc_df     = _sc(X_te_df) if has_test_slice else X_te_df.copy()
    X_te_all_sc_df = _sc(X_te_all_local)
    X_cal_sc_df    = _sc(X_cal_df)

    print(f"[scale] {name}: train mean ~ {X_tr_sc_df.values.mean():.3f}, std ~ {X_tr_sc_df.values.std():.3f}")

    # ---- Base learners
    NB_model  = MLTraining.train_NB (X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    XGB_model = MLTraining.train_XGB(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    KNN_model = MLTraining.train_KNN(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    MLP_model = MLTraining.train_MLP(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)

    # ---- Stacker (raw)
    # Logistic regression on the base learners' predicted probabilities.
    stacker_raw = StackingClassifier(
        estimators=[("NB", NB_model), ("XGB", XGB_model), ("KNN", KNN_model), ("MLP", MLP_model)],
        final_estimator=LogisticRegression(max_iter=2000, class_weight="balanced", random_state=42),
        stack_method="predict_proba",
        cv=kf,
        n_jobs=num_cores,
    ).fit(X_tr_sc_df, y_tr)

    # ---- Feature count asserts (debug safety)
    expected_feats = len(cols_train)
    _assert_feature_counts(name, {
        "NB": NB_model, "XGB": XGB_model, "KNN": KNN_model, "MLP": MLP_model, "Stacker": stacker_raw
    }, expected_feats)

    # ---- Binary calibration (Platt, guarded)
    pos_cal    = int(y_cal_bin.sum())
    n_cal_bin  = int(len(y_cal_bin))
    has_both   = (0 < pos_cal < n_cal_bin)   # calibration needs both classes in CAL
    print(f"[CAL] {name}: cal positives={pos_cal}/{n_cal_bin}")

    if has_both:
        try:
            calibrator = CalibratedClassifierCV(estimator=stacker_raw, method="sigmoid", cv="prefit")
        except TypeError:  # older sklearn
            calibrator = CalibratedClassifierCV(base_estimator=stacker_raw, method="sigmoid", cv="prefit")
        stacker = calibrator.fit(X_cal_sc_df, y_cal_bin)
    else:
        print(f"[WARN] Skipping calibration for {name}: single-class cal set.")
        stacker = stacker_raw

    # ---- Calibration plot on all-test (optional)
    # Best effort: a plotting failure is reported and must not abort training.
    try:
        y_proba_uncal = stacker_raw.predict_proba(X_te_all_sc_df)[:, 1]
        y_proba_cal   = stacker.predict_proba(X_te_all_sc_df)[:, 1]
        if has_both:
            _ = MLTraining.plot_calibration_curve(
                y_te_all, [y_proba_uncal, y_proba_cal],
                clf_names=["Uncalibrated", "Calibrated"],
                n_bins=15, strategy="quantile",
                title=f"Calibration – {name_target_class}:{name}"
            )
        else:
            _ = MLTraining.plot_calibration_curve(
                y_te_all, [y_proba_uncal],
                clf_names=["Uncalibrated"],
                n_bins=15, strategy="quantile",
                title=f"Calibration (uncal only) – {name_target_class}:{name}"
            )
        plt.tight_layout()
        plt.show()
    except Exception as e:
        print(f"[WARN] Skipped calibration plot for {name}: {e}")

    # ---- Save per-class bundle (model + scaler + columns)
    save_subdir = models_dir / f"{name_target_class}_{name}"
    save_subdir.mkdir(parents=True, exist_ok=True)

    MLTraining.save_models({"Stacked": stacker}, out_dir=save_subdir, tag=f"{name_target_class}_{name}")
    joblib.dump(cols_train, save_subdir / "feature_names.joblib")

    bundle = {
        "atlas": "Luecken",
        "depth": name_target_class,
        "label": celltype,
        "model": stacker,          # CalibratedClassifierCV(StackingClassifier) or StackingClassifier
        "columns": cols_train,     # exact panel names, panel order
        "scaler": scaler,          # per-head scaler
        "panel_name": "TotalSeqD_Heme_Oncology_CAT399906",
    }
    bundle_path = save_subdir / f"{name_target_class}_{name}_bundle.joblib"
    joblib.dump(bundle, bundle_path)
    print(f"[SAVE] Wrote bundle with columns+scaler to {bundle_path}")

    # ---- Binary metrics on the class-specific test slice
    if has_test_slice:
        try:
            m = MLTraining.evaluate_classifier(stacker, X_te_sc_df, y_te, plot_cm=False)
            m.update(celltype=celltype)
            metrics_log.append(m)
            print(f"\n{celltype}\n", m.get("report", ""))
        except Exception as e:
            print(f"[WARN] Binary metrics for {name} skipped: {e}")
    else:
        print(f"[WARN] Binary metrics for {name} skipped: none of the listed test barcodes are in the TEST split.")

    # ---- Store OvR probs for multiclass calibration (columns order = class_names)
    P_cal[:, k] = stacker.predict_proba(X_cal_sc_df)[:, 1]
    P_te[:,  k] = stacker.predict_proba(X_te_all_sc_df)[:, 1]

# ============================= MULTICLASS CALIBRATION =============================
# Fit one TemperatureScaling model on the stacked OvR probabilities of the CAL split and
# apply it to the TEST probabilities; fall back to row-normalized OvR probabilities when
# the scaler does not return one column per included class.

print("\nFitting multiclass TemperatureScaling on CAL split (excluded classes masked out)...")

# Guards: ensure probs are in [0,1]
if (P_cal < 0).any() or (P_cal > 1).any():
    raise ValueError("P_cal must be probabilities in [0,1].")
if (P_te < 0).any() or (P_te > 1).any():
    raise ValueError("P_te must be probabilities in [0,1].")

ts_cal = TemperatureScaling()
# Fit only on CAL rows whose true label is one of the included classes
ts_cal.fit(P_cal[mask_cal_mc.values, :], y_cal_multiclass)
P_te_mc = ts_cal.transform(P_te)

# Ensure calibrated probs shape (K)
P_te_mc = np.asarray(P_te_mc)
if P_te_mc.ndim == 1:
    P_te_mc = P_te_mc.reshape(-1, 1)
if P_te_mc.shape[1] == 1 and K == 2:
    # Single positive-class column in the binary case: rebuild the two-column layout.
    P_te_mc = np.hstack([1.0 - P_te_mc, P_te_mc])
elif P_te_mc.shape[1] != K:
    # Remember what the scaler produced BEFORE overwriting it, so the warning reports
    # the unexpected shape rather than the shape of the fallback matrix.
    returned_shape = P_te_mc.shape
    # Fallback: normalize OvR sums
    row_sums = P_te.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0.0] = 1.0   # all-zero rows stay zero instead of dividing by 0
    P_te_mc = P_te / row_sums
    print(f"[WARN] TemperatureScaling returned shape {returned_shape}; fell back to sum-normalized OvR probs.")

# Persist multiclass temp scaler + INCLUDED class names
joblib.dump(ts_cal, models_dir / f"{name_target_class}_multiclass_temp_scaler.joblib")
(pd.Series(class_names, name="class_name")
   .to_csv(models_dir / f"{name_target_class}_class_names.csv", index=False))

# ============================= PROBS COMPARISON & METRICS =============================
# Compare raw OvR and temperature-scaled probabilities on TEST, save them, and report
# multiclass metrics plus the per-class binary metrics collected during training.

# Evaluate & save on TEST rows whose true label is an INCLUDED class
test_index_masked = Luecken_data_Test_Sub.index[mask_test_mc.values]

probs_raw_df = pd.DataFrame(P_te[mask_test_mc.values, :],    index=test_index_masked,
                            columns=[f"raw_{c}" for c in class_names])
probs_mc_df  = pd.DataFrame(P_te_mc[mask_test_mc.values, :], index=test_index_masked,
                            columns=[f"mc_{c}"  for c in class_names])

probs_compare = pd.concat([probs_raw_df, probs_mc_df], axis=1)
probs_compare["true_label"]    = Luecken_data_Test.loc[mask_test_mc, "Celltype"].values
probs_compare["pred_raw"]      = P_te[mask_test_mc.values, :].argmax(axis=1)
probs_compare["pred_mc"]       = P_te_mc[mask_test_mc.values, :].argmax(axis=1)
probs_compare["pred_raw_name"] = [class_names[i] for i in probs_compare["pred_raw"].values]
probs_compare["pred_mc_name"]  = [class_names[i] for i in probs_compare["pred_mc"].values]

print("\nPreview of probabilities BEFORE (raw OvR) vs AFTER (multiclass TS) [included classes only]:")
print(probs_compare.head(10).to_string())

probs_compare_path = models_dir / f"{name_target_class}_probabilities_before_after_TEST_included.csv"
probs_compare.to_csv(probs_compare_path, index=True)
print(f"\nSaved probabilities comparison to: {probs_compare_path}")

# Multiclass evaluation on the masked subset
y_pred_mc = P_te_mc[mask_test_mc.values, :].argmax(axis=1)
print("\nMulticlass classification report (TEST, excluded classes removed):")
# `labels` pins the report to all K included classes. Without it the report infers the
# classes from the data and raises when an included class is absent from both the true
# and the predicted labels, because `target_names` would no longer match in length.
print(classification_report(
    y_test_multiclass, y_pred_mc,
    labels=list(range(K)), target_names=class_names, digits=3,
))

cm = confusion_matrix(y_test_multiclass, y_pred_mc, labels=list(range(K)))
print("Confusion matrix (rows=true, cols=pred):\n", cm)

# Per-class binary head metrics CSV
metrics_df = pd.DataFrame.from_records(metrics_log)
MLTraining.append_metrics_csv(metrics_df, csv_path=Path(models_output) / "stacker_metrics.csv")

print("\nDone.")


#### Detailed annotation

In [None]:
# -*- coding: utf-8 -*-
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.ensemble import StackingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.calibration import CalibratedClassifierCV
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix

from netcal.scaling import TemperatureScaling
import joblib

# ------------------------------------------------------------------- CONFIG (expects these to already exist)
#   models_output, train_barcodes_path, test_barcodes_path
#   Luecken_data_Train, Luecken_data_Test, Luecken_data_Cal          (DataFrames indexed by barcode)
#   Luecken_dataset_Train, Luecken_dataset_Test, Luecken_dataset_Cal (AnnData with obs labels)
#   TotalSeqD_Heme_Oncology_CAT399906                    (iterable of feature names)
#   MLTraining module with: CV, train_NB, train_XGB, train_KNN, train_MLP,
#                           plot_calibration_curve, save_models, evaluate_classifier, append_metrics_csv

# Annotation depth handled by this cell. The value is lower-cased to build the
# AnnData.obs label column name used further down.
name_target_class = "Detailed"
# Output folders under `models_output`; both are created if they do not exist yet.
fig_root   = Path(models_output) / "Figures"
models_dir = Path(models_output) / "Models"
fig_root.mkdir(parents=True, exist_ok=True)
models_dir.mkdir(parents=True, exist_ok=True)

# NOTE: `kf` and `num_cores` rebind the notebook-level values created in the set-up
# cell (a 5-fold KFold and cpu_count()-2) for every cell executed afterwards.
kf         = MLTraining.CV   # presumably a CV splitter shared by the learners — confirm in MLTraining
num_cores  = -1              # presumably "use all cores" (joblib convention) — confirm in MLTraining
metrics_log = []             # starts empty; presumably filled with per-class metrics later in the cell

# ============================= HELPERS =============================

def _norm_feats(names) -> pd.Index:
    """
    Normalizer used ONLY to construct matching keys.
    Panel names remain untouched; data columns are normalized and then mapped BACK
    to the exact panel names via a lookup.
    """
    s = pd.Index(map(str, names))
    s = (s.str.strip()
           .str.lower()
           .str.replace(r"[ _/]+", "-", regex=True)
           .str.replace(r"-+", "-", regex=True)
           .str.strip("-"))
    return s

def attach_celltype(df: pd.DataFrame, ad: "AnnData", field: str) -> pd.DataFrame:
    """
    Return a copy of ``df`` carrying a categorical ``Celltype`` column read from ``ad.obs[field]``.

    Labels are converted to strings, trimmed, and every whitespace run is
    replaced by an underscore. They are then aligned to ``df`` by index label;
    rows of ``df`` without a matching entry in ``ad.obs`` receive a missing
    label and a warning with their count is printed.

    Raises:
        KeyError: if ``field`` is not a column of ``ad.obs``.
    """
    if field not in ad.obs:
        raise KeyError(f"'{field}' not found in AnnData.obs")

    labels = ad.obs[field].astype("string")
    labels = labels.str.strip()
    labels = labels.str.replace(r"\s+", "_", regex=True)

    result = df.copy()
    result["Celltype"] = pd.Categorical(labels.reindex(result.index))

    n_missing = int(result["Celltype"].isna().sum())
    if n_missing > 0:
        print(f"[WARN] {n_missing} rows got NaN Celltype after reindex; check barcode alignment.")
    return result

def _check_finite(df: pd.DataFrame, tag: str):
    arr = df.to_numpy()
    if not np.isfinite(arr).all():
        bad = np.where(~np.isfinite(arr))
        raise ValueError(f"Non-finite values found in {tag} features at positions {bad}")

def _unwrap_estimator(m):
    return getattr(m, "estimator", None) or getattr(m, "base_estimator", None) or m

def _assert_feature_counts(cell_name: str, models_dict: dict, expected: int):
    pairs = [
        ("NB",  models_dict.get("NB")),
        ("XGB", models_dict.get("XGB")),
        ("KNN", models_dict.get("KNN")),
        ("MLP", models_dict.get("MLP")),
        ("Stacker", models_dict.get("Stacker")),
    ]
    for name, est in pairs:
        if est is None:
            continue
        base = _unwrap_estimator(est)
        nfi = getattr(base, "n_features_in_", None)
        if nfi is not None and nfi != expected:
            raise RuntimeError(f"{cell_name}:{name} saw {nfi} features; expected {expected}")

# ============================= LABEL ATTACH =============================

# Name of the AnnData.obs column that holds the consensus labels for this depth.
consensus_field = f"Consensus_annotation_{name_target_class.lower()}_final"

# Each split name is rebound to a labelled copy carrying a categorical "Celltype" column.
Luecken_data_Train = attach_celltype(Luecken_data_Train, Luecken_dataset_Train, consensus_field)
Luecken_data_Test  = attach_celltype(Luecken_data_Test,  Luecken_dataset_Test,  consensus_field)
Luecken_data_Cal   = attach_celltype(Luecken_data_Cal,   Luecken_dataset_Cal,   consensus_field)

# ============================= PANEL & DATA COLUMN ALIGNMENT =============================

# Keep the panel EXACTLY as provided
panel = pd.Index(map(str, TotalSeqD_Heme_Oncology_CAT399906))

# Build a mapping: normalized_key -> exact panel name
panel_keys    = _norm_feats(panel)
norm_to_panel = dict(zip(panel_keys, panel))
# A dict keeps a single entry per key, so a shorter dict means at least two panel
# entries produced the same normalized key (or the panel repeats a name).
if len(norm_to_panel) != len(panel):
    raise ValueError("Panel contains names that collide after normalization. Consider adjusting _norm_feats rules.")

def _rename_data_to_panel(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of ``df`` whose feature columns carry the exact panel names.

    A feature column is renamed when its normalized key (see ``_norm_feats``)
    is present in ``norm_to_panel``. 'cell_barcode' and 'Celltype' are never
    touched, and columns without a panel match keep their name. When several
    data columns resolve to the same panel name, the first one is kept and the
    others are dropped (with a warning). A one-line match summary is printed.
    """
    renamed = df.copy()
    reserved = [c for c in ["cell_barcode", "Celltype"] if c in renamed.columns]
    feature_cols = pd.Index([c for c in renamed.columns if c not in reserved])

    # data column -> exact panel name, for every column whose key the panel knows
    candidates = {}
    for col, key in zip(feature_cols, _norm_feats(feature_cols)):
        target = norm_to_panel.get(key)
        if target is not None:
            candidates[col] = target

    # The first column claiming a panel name wins; later claimants are discarded.
    claimed = set()
    first_claim = {}
    duplicates = []
    for col, target in candidates.items():
        if target in claimed:
            duplicates.append(col)
            continue
        claimed.add(target)
        first_claim[col] = target

    if duplicates:
        print(f"[WARN] Dropping {len(duplicates)} duplicated-mapped columns (showing up to 5): {duplicates[:5]}")
        renamed = renamed.drop(columns=duplicates, errors="ignore")
    renamed = renamed.rename(columns=first_claim)

    print(f"[map] matched {len(first_claim)}/{len(feature_cols)} data columns to panel")
    return renamed

# Apply: normalize/rename ONLY data splits (panel remains untouched).
# Each call returns a renamed copy and prints how many data columns matched the panel.
Luecken_data_Train = _rename_data_to_panel(Luecken_data_Train)
Luecken_data_Test  = _rename_data_to_panel(Luecken_data_Test)
Luecken_data_Cal   = _rename_data_to_panel(Luecken_data_Cal)

# Intersect each split with the panel IN PANEL ORDER
def _panel_intersection(df: pd.DataFrame) -> pd.DataFrame:
    """
    Restrict ``df`` to the feature columns it shares with ``panel``.

    The shared features are taken from ``panel.intersection(..., sort=False)``
    and placed first, followed by whichever of 'cell_barcode' / 'Celltype'
    exist in ``df``.

    Raises
    ------
    ValueError
        If no feature column of ``df`` is present in the panel.
    """
    meta_cols = [name for name in ("cell_barcode", "Celltype") if name in df.columns]
    candidate_feats = pd.Index([col for col in df.columns if col not in meta_cols])

    shared = panel.intersection(candidate_feats, sort=False)
    if len(shared) == 0:
        raise ValueError("Panel/Data intersection is empty after renaming. Check mapping rules.")

    ordered_cols = [*shared, *meta_cols]
    return df.reindex(columns=ordered_cols)

# Keep only panel features (+ metadata columns) in every split, Train -> Test -> Cal.
Luecken_data_Train, Luecken_data_Test, Luecken_data_Cal = (
    _panel_intersection(split)
    for split in (Luecken_data_Train, Luecken_data_Test, Luecken_data_Cal)
)

# ============================= FEATURES & LABELS =============================

# CAL ground-truth labels, kept as their own single-column frame.
Luecken_data_Cal_lbl = Luecken_data_Cal[["Celltype"]].copy()

# Metadata (non-feature) columns that are actually present in each split.
drop_cols_train, drop_cols_test, drop_cols_cal = (
    [meta for meta in ("cell_barcode", "Celltype") if meta in split.columns]
    for split in (Luecken_data_Train, Luecken_data_Test, Luecken_data_Cal)
)

# Feature-only views of the three splits.
Luecken_data_Train_Sub = Luecken_data_Train.drop(columns=drop_cols_train, errors="ignore")
Luecken_data_Test_Sub  = Luecken_data_Test.drop(columns=drop_cols_test,   errors="ignore")
Luecken_data_Cal_Sub   = Luecken_data_Cal.drop(columns=drop_cols_cal,     errors="ignore")

# SAFETY: all splits must expose the same features in the same order, with finite values.
cols_train = list(Luecken_data_Train_Sub.columns)
if not all(
    list(other.columns) == cols_train
    for other in (Luecken_data_Test_Sub, Luecken_data_Cal_Sub)
):
    raise ValueError("Train/Cal/Test feature columns differ after panel intersection!")

for split_name, split_frame in (
    ("TRAIN", Luecken_data_Train_Sub),
    ("TEST",  Luecken_data_Test_Sub),
    ("CAL",   Luecken_data_Cal_Sub),
):
    _check_finite(split_frame, split_name)

print(f"\n[features] Using {len(cols_train)} panel-intersected features (exact panel names):")
print(cols_train)

# ===== Exclude specific classes from the multiclass set and per-class loop =====
EXCLUDE_CLASSES = {"Macrophage", "ILC", "Stroma", "dnT"}

# Classes are discovered from the TRAIN labels; excluded ones are filtered out.
all_classes = sorted(Luecken_data_Train["Celltype"].dropna().unique())
class_names = [label for label in all_classes if label not in EXCLUDE_CLASSES]
if len(class_names) == 0:
    raise ValueError("After exclusions, class_names is empty.")
print(f"[classes] Included ({len(class_names)}): {class_names}")

# Excluded classes that really occur in TRAIN (reported for information only).
missing = [label for label in all_classes if label in EXCLUDE_CLASSES]
if missing:
    print(f"[classes] Excluded: {missing}")

K            = len(class_names)
class_to_idx = dict(zip(class_names, range(K)))

# --- Multiclass labels (MASKED to included classes) ---
# CAL: integer class index for every CAL row whose label is an included class.
cal_celltypes = Luecken_data_Cal_lbl["Celltype"]
mask_cal_mc   = cal_celltypes.isin(class_names)
s_cal         = cal_celltypes.loc[mask_cal_mc].map(class_to_idx)
if s_cal.isna().any():
    missing = cal_celltypes.loc[mask_cal_mc & s_cal.isna()].unique()
    raise ValueError(f"Unknown labels in masked CAL split: {missing}")
y_cal_multiclass = s_cal.to_numpy(dtype=np.int64)

# TEST: same encoding for the TEST split.
test_celltypes = Luecken_data_Test["Celltype"]
mask_test_mc   = test_celltypes.isin(class_names)
s_te           = test_celltypes.loc[mask_test_mc].map(class_to_idx)
if s_te.isna().any():
    missing = test_celltypes.loc[mask_test_mc & s_te.isna()].unique()
    raise ValueError(f"Unknown labels in masked TEST split: {missing}")
y_test_multiclass = s_te.to_numpy(dtype=np.int64)

# Full CAL / TEST feature frames, reused by every per-class head.
X_cal_all_df = Luecken_data_Cal_Sub.copy()
X_te_all_df  = Luecken_data_Test_Sub.copy()

# One-vs-rest probability matrices: one row per cell, one column per included class.
P_cal = np.zeros((len(X_cal_all_df), K), dtype=float)
P_te  = np.zeros((len(X_te_all_df),  K), dtype=float)

test_index = Luecken_data_Test_Sub.index

# ============================= TRAIN PER-CLASS OVR =============================
# For every included class this loop trains one binary one-vs-rest "head":
#   1. slice TRAIN / TEST rows using the per-class positive/negative barcode CSVs,
#   2. fit a per-head StandardScaler on the TRAIN slice and apply it to all splits,
#   3. train NB / XGB / KNN / MLP base learners plus a logistic-regression stacker,
#   4. sigmoid-calibrate the stacker on the CAL split when CAL contains both classes,
#   5. save the model bundle, log binary metrics, and write the head's positive-class
#      probabilities into column k of P_cal / P_te.
# Relies on names defined earlier in the notebook (e.g. name_target_class, models_dir,
# metrics_log, train_barcodes_path, test_barcodes_path, joblib, CalibratedClassifierCV).

# Classes whose head could not be trained; their P_cal / P_te columns stay all-zero.
skipped_heads = []

for celltype in class_names:
    k = class_to_idx[celltype]
    name = str(celltype).replace(" ", "_")
    print(f"\nProcessing {name} (class {k+1}/{K})...")

    # ---- TRAIN slice via barcode lists
    train_barcodes_df = pd.read_csv(
        f"{train_barcodes_path}/Luecken/Consensus_annotation_detailed_final/Barcodes_training_class_{name}.csv",
        index_col=0
    )
    train_positive_barcodes = train_barcodes_df["Positive"].dropna().values
    train_negative_barcodes = train_barcodes_df["Negative"].dropna().values
    all_train_barcodes = np.concatenate([train_positive_barcodes, train_negative_barcodes])

    train_mask = Luecken_data_Train_Sub.index.isin(all_train_barcodes)
    X_tr_df = Luecken_data_Train_Sub.loc[train_mask]
    found_train_barcodes = X_tr_df.index.values
    # Label = 1 when the found barcode is listed as positive, else 0.
    y_tr = np.isin(found_train_barcodes, train_positive_barcodes).astype(int)

    # ---- Skip guards
    if X_tr_df.empty or np.unique(y_tr).size < 2:
        print(f"[SKIP] {name}: empty or single-class train slice (pos={y_tr.sum()}, neg={(len(y_tr)-y_tr.sum())}).")
        skipped_heads.append(celltype)
        continue

    # ---- TEST slice via barcode lists
    test_barcodes_df = pd.read_csv(
        f"{test_barcodes_path}/Luecken/Consensus_annotation_detailed_final/Barcodes_testing_class_{name}.csv",
        index_col=0
    )
    test_positive_barcodes = test_barcodes_df["Positive"].dropna().values
    test_negative_barcodes = test_barcodes_df["Negative"].dropna().values
    all_test_barcodes = np.concatenate([test_positive_barcodes, test_negative_barcodes])

    test_mask = Luecken_data_Test_Sub.index.isin(all_test_barcodes)
    X_te_df = Luecken_data_Test_Sub.loc[test_mask]
    found_test_barcodes = X_te_df.index.values
    y_te = np.isin(found_test_barcodes, test_positive_barcodes).astype(int)

    # ---- Full-test & cal for this binary head
    # The shared frames are only read below (scaling returns new frames), so no
    # per-iteration copy is needed.
    X_te_all_local = X_te_all_df
    y_te_all = (Luecken_data_Test["Celltype"].values == celltype).astype(int)
    X_cal_df = X_cal_all_df
    y_cal_bin = (Luecken_data_Cal_lbl["Celltype"].values == celltype).astype(int)

    # ---- Info
    print(f"Training - Found {X_tr_df.shape[0]} / {len(all_train_barcodes)} barcodes")
    print(f"Training - Pos: {len(train_positive_barcodes)}, Neg: {len(train_negative_barcodes)}")
    print(f"Training labels: {y_tr.sum()} pos, {len(y_tr)-y_tr.sum()} neg")
    print(f"Testing  - Found {X_te_df.shape[0]} / {len(all_test_barcodes)} barcodes")
    print(f"Testing  - Pos: {len(test_positive_barcodes)}, Neg: {len(test_negative_barcodes)}")
    print(f"Testing  - labels: {y_te.sum()} pos, {len(y_te)-y_te.sum()} neg")
    print(f"Calibrating - Found {X_cal_df.shape[0]} rows | Pos: {y_cal_bin.sum()}, Neg: {len(y_cal_bin)-y_cal_bin.sum()}")
    print(f"All test data: {X_te_all_local.shape[0]} rows, positives for {celltype}: {y_te_all.sum()}")

    # ---- Scaling (fit on per-head TRAIN slice; transform others)
    scaler = StandardScaler(with_mean=True, with_std=True).fit(X_tr_df.values)

    def _sc(df):
        """Apply this head's fitted scaler to ``df`` and return a labelled DataFrame."""
        return pd.DataFrame(
            scaler.transform(df.values),
            index=df.index,
            columns=cols_train,
        )

    X_tr_sc_df     = _sc(X_tr_df)
    # StandardScaler.transform rejects 0-row input; an empty class-specific test
    # slice must not abort training, so it is passed through unscaled (still empty).
    X_te_sc_df     = _sc(X_te_df) if not X_te_df.empty else X_te_df.copy()
    X_te_all_sc_df = _sc(X_te_all_local)
    X_cal_sc_df    = _sc(X_cal_df)

    print(f"[scale] {name}: train mean ~ {X_tr_sc_df.values.mean():.3f}, std ~ {X_tr_sc_df.values.std():.3f}")

    # ---- Base learners
    NB_model  = MLTraining.train_NB (X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    XGB_model = MLTraining.train_XGB(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    KNN_model = MLTraining.train_KNN(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)
    MLP_model = MLTraining.train_MLP(X_tr_sc_df, y_tr, cv=kf, num_cores=num_cores, name_target_subclass=name)

    # ---- Stacker (raw)
    stacker_raw = StackingClassifier(
        estimators=[("NB", NB_model), ("XGB", XGB_model), ("KNN", KNN_model), ("MLP", MLP_model)],
        final_estimator=LogisticRegression(max_iter=2000, class_weight="balanced", random_state=42),
        stack_method="predict_proba",
        cv=kf,
        n_jobs=-1,
    ).fit(X_tr_sc_df, y_tr)

    # ---- Feature count asserts (debug safety)
    expected_feats = len(cols_train)
    _assert_feature_counts(name, {
        "NB": NB_model, "XGB": XGB_model, "KNN": KNN_model, "MLP": MLP_model, "Stacker": stacker_raw
    }, expected_feats)

    # ---- Binary calibration (Platt, guarded)
    pos_cal    = int(y_cal_bin.sum())
    n_cal_bin  = int(len(y_cal_bin))
    has_both   = (0 < pos_cal < n_cal_bin)  # CAL must contain positives AND negatives
    print(f"[CAL] {name}: cal positives={pos_cal}/{n_cal_bin}")

    if has_both:
        try:
            calibrator = CalibratedClassifierCV(estimator=stacker_raw, method="sigmoid", cv="prefit")
        except TypeError:  # older sklearn
            calibrator = CalibratedClassifierCV(base_estimator=stacker_raw, method="sigmoid", cv="prefit")
        stacker = calibrator.fit(X_cal_sc_df, y_cal_bin)
    else:
        print(f"[WARN] Skipping calibration for {name}: single-class cal set.")
        stacker = stacker_raw

    # ---- Calibration plot on all-test (optional)
    # Best-effort: a plotting failure is reported but never stops training.
    try:
        y_proba_uncal = stacker_raw.predict_proba(X_te_all_sc_df)[:, 1]
        y_proba_cal   = stacker.predict_proba(X_te_all_sc_df)[:, 1]
        if has_both:
            _ = MLTraining.plot_calibration_curve(
                y_te_all, [y_proba_uncal, y_proba_cal],
                clf_names=["Uncalibrated", "Calibrated"],
                n_bins=15, strategy="quantile",
                title=f"Calibration – {name_target_class}:{name}"
            )
        else:
            _ = MLTraining.plot_calibration_curve(
                y_te_all, [y_proba_uncal],
                clf_names=["Uncalibrated"],
                n_bins=15, strategy="quantile",
                title=f"Calibration (uncal only) – {name_target_class}:{name}"
            )
        plt.tight_layout()
        plt.show()
    except Exception as e:
        print(f"[WARN] Skipped calibration plot for {name}: {e}")

    # ---- Save per-class bundle (model + scaler + columns)
    save_subdir = models_dir / f"{name_target_class}_{name}"
    save_subdir.mkdir(parents=True, exist_ok=True)

    MLTraining.save_models({"Stacked": stacker}, out_dir=save_subdir, tag=f"{name_target_class}_{name}")
    joblib.dump(cols_train, save_subdir / "feature_names.joblib")

    bundle = {
        "atlas": "Luecken",
        "depth": name_target_class,
        "label": celltype,
        "model": stacker,          # CalibratedClassifierCV(StackingClassifier) or StackingClassifier
        "columns": cols_train,     # exact panel names, panel order
        "scaler": scaler,          # per-head scaler
        "panel_name": "TotalSeqD_Heme_Oncology_CAT399906",
    }
    bundle_path = save_subdir / f"{name_target_class}_{name}_bundle.joblib"
    joblib.dump(bundle, bundle_path)
    print(f"[SAVE] Wrote bundle with columns+scaler to {bundle_path}")

    # ---- Binary metrics on the class-specific test slice
    if X_te_sc_df.empty:
        print(f"[WARN] Binary metrics for {name} skipped: no test barcodes found in the TEST split.")
    else:
        try:
            m = MLTraining.evaluate_classifier(stacker, X_te_sc_df, y_te, plot_cm=False)
            m.update(celltype=celltype)
            metrics_log.append(m)
            print(f"\n{celltype}\n", m.get("report", ""))
        except Exception as e:
            print(f"[WARN] Binary metrics for {name} skipped: {e}")

    # ---- Store OvR probs for multiclass calibration (columns order = class_names)
    P_cal[:, k] = stacker.predict_proba(X_cal_sc_df)[:, 1]
    P_te[:,  k] = stacker.predict_proba(X_te_all_sc_df)[:, 1]

# Skipped heads leave all-zero probability columns that still enter the multiclass
# calibration below; surface this so it is not silent.
if skipped_heads:
    print(f"[WARN] {len(skipped_heads)} class head(s) were skipped; their OvR probability columns remain 0: {skipped_heads}")

# ============================= MULTICLASS CALIBRATION =============================
# Fit a TemperatureScaling calibrator on the stacked one-vs-rest probabilities of the
# CAL split (rows of included classes only), apply it to the full TEST matrix, and
# persist the calibrator together with the ordered list of included class names.

print("\nFitting multiclass TemperatureScaling on CAL split (excluded classes masked out)...")

# Guards: ensure probs are in [0,1].
# Written as a positive range test so that NaN entries (for which `<` and `>` are
# both False) are rejected as well.
if not ((P_cal >= 0) & (P_cal <= 1)).all():
    raise ValueError("P_cal must be probabilities in [0,1].")
if not ((P_te >= 0) & (P_te <= 1)).all():
    raise ValueError("P_te must be probabilities in [0,1].")

ts_cal = TemperatureScaling()
# Fit only on CAL rows whose true label is one of the included classes
ts_cal.fit(P_cal[mask_cal_mc.values, :], y_cal_multiclass)
P_te_mc = ts_cal.transform(P_te)

# Ensure calibrated probs shape (n_rows, K)
P_te_mc = np.asarray(P_te_mc)
if P_te_mc.ndim == 1:
    P_te_mc = P_te_mc.reshape(-1, 1)
if P_te_mc.shape[1] == 1 and K == 2:
    # Single positive-class column for a binary problem: expand to two columns.
    P_te_mc = np.hstack([1.0 - P_te_mc, P_te_mc])
elif P_te_mc.shape[1] != K:
    # Fallback: normalize OvR sums.
    # Remember the unexpected shape BEFORE overwriting P_te_mc so the warning
    # reports what TemperatureScaling actually returned.
    returned_shape = P_te_mc.shape
    row_sums = P_te.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0.0] = 1.0  # avoid division by zero for all-zero rows
    P_te_mc = P_te / row_sums
    print(f"[WARN] TemperatureScaling returned shape {returned_shape}; fell back to sum-normalized OvR probs.")

# Persist multiclass temp scaler + INCLUDED class names
joblib.dump(ts_cal, models_dir / f"{name_target_class}_multiclass_temp_scaler.joblib")
(pd.Series(class_names, name="class_name")
   .to_csv(models_dir / f"{name_target_class}_class_names.csv", index=False))

# ============================= PROBS COMPARISON & METRICS =============================
# Compare raw one-vs-rest probabilities with the temperature-scaled ones on TEST rows
# whose true label is an included class, save the comparison table, and report
# multiclass metrics plus the per-head binary metrics collected in metrics_log.

# Evaluate & save on TEST rows whose true label is an INCLUDED class
test_index_masked = Luecken_data_Test_Sub.index[mask_test_mc.values]

# Masked probability matrices and their argmax predictions, computed once and reused.
P_te_raw_masked = P_te[mask_test_mc.values, :]
P_te_mc_masked  = P_te_mc[mask_test_mc.values, :]
y_pred_raw      = P_te_raw_masked.argmax(axis=1)
y_pred_mc       = P_te_mc_masked.argmax(axis=1)

probs_raw_df = pd.DataFrame(P_te_raw_masked, index=test_index_masked,
                            columns=[f"raw_{c}" for c in class_names])
probs_mc_df  = pd.DataFrame(P_te_mc_masked,  index=test_index_masked,
                            columns=[f"mc_{c}"  for c in class_names])

probs_compare = pd.concat([probs_raw_df, probs_mc_df], axis=1)
probs_compare["true_label"]    = Luecken_data_Test.loc[mask_test_mc, "Celltype"].values
probs_compare["pred_raw"]      = y_pred_raw
probs_compare["pred_mc"]       = y_pred_mc
probs_compare["pred_raw_name"] = [class_names[i] for i in y_pred_raw]
probs_compare["pred_mc_name"]  = [class_names[i] for i in y_pred_mc]

print("\nPreview of probabilities BEFORE (raw OvR) vs AFTER (multiclass TS) [included classes only]:")
print(probs_compare.head(10).to_string())

probs_compare_path = models_dir / f"{name_target_class}_probabilities_before_after_TEST_included.csv"
probs_compare.to_csv(probs_compare_path, index=True)
print(f"\nSaved probabilities comparison to: {probs_compare_path}")

# Multiclass evaluation on the masked subset.
# Pass explicit labels: without them classification_report raises a ValueError when
# an included class appears in neither y_true nor y_pred, because the number of
# observed labels then differs from len(target_names).
label_ids = list(range(K))
print("\nMulticlass classification report (TEST, excluded classes removed):")
print(classification_report(y_test_multiclass, y_pred_mc, labels=label_ids,
                            target_names=class_names, digits=3))

cm = confusion_matrix(y_test_multiclass, y_pred_mc, labels=label_ids)
print("Confusion matrix (rows=true, cols=pred):\n", cm)

# Per-class binary head metrics CSV
metrics_df = pd.DataFrame.from_records(metrics_log)
MLTraining.append_metrics_csv(metrics_df, csv_path=Path(models_output) / "stacker_metrics.csv")

print("\nDone.")
