# In [None]:
# %% [markdown]
# # CBC Multi-Output (v2) — HGB, HCT, RBC
#
# - Usa el dataset: cbc_synthetic_30000_enriched_v2.csv
# - Elimina features fuertemente correlacionadas (|corr| > 0.80) antes de entrenar
# - Entrena modelo multi-output y modelos single-output
# - Calcula SHAP para ambos casos
# - Usa nombres "bonitos" para features y outputs en tablas y gráficos
# - Mide tiempo y memoria de entrenamiento y SHAP (multi vs single-output)
#   - Memoria medida como TOTAL PEAK en MB (decimal), NO delta.

# %%
import os, json, time
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim

import shap
from memory_profiler import memory_usage  # para medir memoria (devuelve MiB)

from sklearn.model_selection import train_test_split 
from sklearn.preprocessing import StandardScaler

from sklearn.metrics import (
    r2_score,
    mean_absolute_error,
    mean_squared_error,
    f1_score        
)

# Base directories (adjust BASE_DIR if needed). Every artifact is written under
# OUT_DIR: figures go to IMG_DIR and CSV tables to TAB_DIR.
BASE_DIR = Path(".")
OUT_DIR = BASE_DIR.joinpath("cbc_multi_output_v2_outputs_time_memory_MB_05")
IMG_DIR = OUT_DIR.joinpath("figs")
TAB_DIR = OUT_DIR.joinpath("tables")

# Create the whole output tree up front; exist_ok keeps re-runs harmless.
for out_folder in (OUT_DIR, IMG_DIR, TAB_DIR):
    out_folder.mkdir(parents=True, exist_ok=True)

def save_current_fig(path_no_ext: Path):
    """Save the active matplotlib figure as both PNG and PDF.

    ``path_no_ext`` is the destination without extension (any suffix it has is
    replaced). Both files use a tight bounding box at 300 dpi.
    """
    base = Path(path_no_ext)
    for ext in (".png", ".pdf"):
        plt.savefig(base.with_suffix(ext), bbox_inches="tight", dpi=300)

def rmse_compat(y_true, y_pred):
    """Return the root-mean-squared error between ``y_true`` and ``y_pred``.

    Both inputs are converted to float64 arrays, so float32 targets do not lose
    precision in the squared-error sum. The shapes must match exactly. Without
    this check NumPy would silently broadcast e.g. (n,) against (n, 1) into an
    (n, n) matrix and return a meaningless value.

    Raises:
        ValueError: if ``y_true`` and ``y_pred`` have different shapes.
    """
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"rmse_compat: shape mismatch {y_true.shape} vs {y_pred.shape}"
        )
    return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# ============================
# Global runtime/memory log (MB, decimal): add_runtime_record() appends one
# dict per profiled stage.
# ============================
runtime_records = []

def add_runtime_record(stage: str, detail: str, time_sec: float, mem_peak_mb: float, mem_start_mb: float, mem_end_mb: float):
    """Append one timing/memory measurement to the global ``runtime_records`` list.

    Args:
        stage: 'train_multi', 'train_single', 'shap_multi' or 'shap_single'.
        detail: e.g. 'multi-output' or the raw target name ('y_hgb_gdl', ...).
        time_sec: elapsed time in seconds.
        mem_peak_mb: peak TOTAL memory in MB (decimal).
        mem_start_mb: memory at the start of the stage, in MB.
        mem_end_mb: memory at the end of the stage, in MB.

    Numeric values are coerced to plain ``float`` before storing.
    """
    record = dict(
        stage=stage,
        detail=detail,
        time_seconds=float(time_sec),
        memory_peak_MB=float(mem_peak_mb),
        memory_start_MB=float(mem_start_mb),
        memory_end_MB=float(mem_end_mb),
    )
    runtime_records.append(record)

def profile_stage_total_MB(fn, interval: float = 0.05):
    """Run ``fn()`` while sampling process memory, and return its result plus stats.

    Returns:
        ``(out, elapsed_seconds, mem_peak_mb, mem_start_mb, mem_end_mb)`` where
        ``out`` is whatever ``fn()`` returned.

    Memory is the process RSS sampled by ``memory_profiler.memory_usage`` every
    ``interval`` seconds. memory_usage reports MiB; the values are converted to
    decimal MB with 1 MiB = 1.048576 MB. Only host memory is captured, so GPU
    memory held by torch on CUDA is not included.

    The elapsed time is measured *inside* the profiled call. It therefore
    excludes the start-up and tear-down of memory_profiler's sampling machinery,
    which would otherwise inflate short stages and bias the multi- vs
    single-output comparison.
    NOTE(review): memory_profiler may re-invoke ``fn`` with a smaller interval
    when too few samples were collected. The reported time is that of the last
    invocation, which is the one whose result is returned.
    """
    timing = {}

    def _timed_call():
        # Time only the user function, not the profiler around it.
        start_t = time.perf_counter()
        try:
            return fn()
        finally:
            timing["elapsed"] = time.perf_counter() - start_t

    mem_trace_mib, out = memory_usage((_timed_call, (), {}), retval=True, interval=interval)

    mib_to_mb = 1.048576  # 1 MiB = 1.048576 MB (decimal)
    mem_start_mb = float(mem_trace_mib[0] * mib_to_mb)
    mem_end_mb = float(mem_trace_mib[-1] * mib_to_mb)
    mem_peak_mb = float(max(mem_trace_mib) * mib_to_mb)

    return out, timing["elapsed"], mem_peak_mb, mem_start_mb, mem_end_mb


# %% [markdown]
# ## 1) Load dataset, drop highly correlated features, define display names

# %%
csv_path = BASE_DIR.joinpath("cbc_synthetic_30000_enriched_v2.csv")
df = pd.read_csv(csv_path)

print("Shape (raw):", df.shape)
print("Columns:", df.columns.tolist())

# Raw target (output) column names, in the order used everywhere downstream.
target_cols = ["y_hgb_gdl", "y_hct_pct", "y_rbc_10^12_per_L"]

# Human-readable labels for the targets (tables and plots).
target_display_map = {
    "y_hgb_gdl": "HGB (g/dL)",
    "y_hct_pct": "HCT (%)",
    "y_rbc_10^12_per_L": "RBC (10^12/L)",
}
target_display_names = [target_display_map[col] for col in target_cols]

# Every non-target column is a candidate feature.
feature_cols_raw = [col for col in df.columns if col not in target_cols]

# Raw -> display names for features; a column missing from this map is meant
# to keep its raw name.
feature_display_map = {
    "age_years": "Age (years)",
    "sex_male": "Sex (male=1)",
    "bmi": "BMI",
    "iron_ugdl": "Iron (µg/dL)",
    "ferritin_ngml": "Ferritin (ng/mL)",
    "vitamin_d_ngml": "Vitamin D (ng/mL)",
    "folate_ngml": "Folate (ng/mL)",
    "vitamin_b12_pgml": "Vitamin B12 (pg/mL)",
    "crp_mgL": "CRP (mg/L)",
    "albumin_gdl": "Albumin (g/dL)",
    "creatinine_mgdl": "Creatinine (mg/dL)",
    "egfr_ml_min": "eGFR (mL/min)",
    "sbp_mmHg": "Systolic BP (mmHg)",
    "dbp_mmHg": "Diastolic BP (mmHg)",
    "heart_rate_bpm": "Heart rate (bpm)",
    "wbc_10^9_per_L": "WBC (10^9/L)",
    "plt_10^9_per_L": "Platelets (10^9/L)",
    "smoker": "Smoker (1 = yes)",
    "alcohol_units_per_week": "Alcohol (units/week)",
    "physical_activity_level": "Physical activity level",
}

# 1) Pairwise Pearson correlation among the raw features; prune so that no
#    kept pair exceeds |corr| > threshold.
corr_features = df[feature_cols_raw].corr()
corr_abs = corr_features.abs()

threshold = 0.80
cols = feature_cols_raw

# Every offending pair (upper triangle only), in column order. NaN correlations
# (e.g. constant columns) compare False and are never flagged.
high_corr_pairs = [
    (cols[i], cols[j], corr_features.iloc[i, j])
    for i in range(len(cols))
    for j in range(i + 1, len(cols))
    if corr_abs.iloc[i, j] > threshold
]

print("\nHighly correlated feature pairs (|corr| > 0.80):")
for col_a, col_b, corr_val in high_corr_pairs:
    print(f"{col_a}  <->  {col_b} : corr = {corr_val:.3f}")

# Greedy pruning: keep the earlier column of each pair and drop the later one,
# but only while the earlier column is itself still kept. Without this check a
# chain A~B, B~C (with A, C uncorrelated) would drop both B and C, although
# dropping B alone already resolves the B~C pair. Pairs are visited row by row,
# so a column's keep/drop status is final before it appears as ``col_a``.
to_drop = set()
for col_a, col_b, _ in high_corr_pairs:
    if col_a in to_drop or col_b in to_drop:
        continue
    to_drop.add(col_b)

print("\nFeatures to drop due to high correlation:", to_drop)

# Final feature list (original column order preserved).
feature_cols = [c for c in feature_cols_raw if c not in to_drop]

print("\nNumber of features before:", len(feature_cols_raw))
print("Number of features after drop:", len(feature_cols))
print("Final feature list (kept):")
print(feature_cols)

# Display names aligned with feature_cols (unmapped columns keep their raw name).
display_feature_names = [feature_display_map.get(c, c) for c in feature_cols]

# Persist the pruning decision and the final feature names.
with open(OUT_DIR / "dropped_features_due_corr.json", "w", encoding="utf-8") as f:
    json.dump({
        "threshold": threshold,
        "dropped_features": sorted(to_drop),
        "final_features_raw": feature_cols,
        "final_features_display": display_feature_names
    }, f, indent=2)

# Summary statistics of the three targets.
print("\nTargets description:")
print(df[target_cols].describe())

# Pairwise Pearson correlation among the TRUE outputs (full dataset).
corr_outputs = df[target_cols].corr()
print("\nCorrelation among outputs:")
print(corr_outputs)
corr_outputs.to_csv(TAB_DIR / "corr_outputs_true.csv", index_label="target_raw")

# Heatmap of that correlation, fixed to the [-1, 1] range.
tick_positions = range(len(target_cols))
plt.figure()
plt.imshow(corr_outputs.values, cmap="coolwarm", vmin=-1, vmax=1)
plt.colorbar()
plt.xticks(tick_positions, target_display_names)
plt.yticks(tick_positions, target_display_names)
plt.title("Correlation — TRUE outputs (CBC v2)")
plt.tight_layout()
save_current_fig(IMG_DIR / "corr_heatmap_true_outputs_cbc_v2")
plt.show()


# %% [markdown]
# ## 2) Train/Val/Test split and scaling

# %%
# Feature and target matrices as float32, the dtype the torch tensors use.
X = df[feature_cols].to_numpy(dtype=np.float32)
Y = df[target_cols].to_numpy(dtype=np.float32)

print("X shape:", X.shape, "Y shape:", Y.shape)

# 80 / 10 / 10 split: first hold out 20 %, then halve it into val and test.
X_train, X_temp, Y_train, Y_temp = train_test_split(X, Y, test_size=0.2, random_state=42)
X_val, X_test, Y_val, Y_test = train_test_split(X_temp, Y_temp, test_size=0.5, random_state=42)

print("Train:", X_train.shape, "Val:", X_val.shape, "Test:", X_test.shape)

# Standardize features and targets using TRAIN statistics only.
scaler_X = StandardScaler()
scaler_Y = StandardScaler()

X_train_scaled = scaler_X.fit_transform(X_train)
X_val_scaled, X_test_scaled = scaler_X.transform(X_val), scaler_X.transform(X_test)

Y_train_scaled = scaler_Y.fit_transform(Y_train)
Y_val_scaled, Y_test_scaled = scaler_Y.transform(Y_val), scaler_Y.transform(Y_test)

# Persist scaler parameters plus the column names they refer to.
scaler_info = dict([
    ("feature_cols_raw", feature_cols),
    ("feature_cols_display", display_feature_names),
    ("target_cols_raw", target_cols),
    ("target_cols_display", target_display_names),
    ("X_mean", scaler_X.mean_.tolist()),
    ("X_scale", scaler_X.scale_.tolist()),
    ("Y_mean", scaler_Y.mean_.tolist()),
    ("Y_scale", scaler_Y.scale_.tolist()),
])
with open(OUT_DIR / "scalers_cbc_v2.json", "w") as f:
    json.dump(scaler_info, f, indent=2)

# SHAP works with the raw column names (column order); the display names are
# only used as labels.
clinical_feature_names = feature_cols
clinical_feature_names_display = display_feature_names


# %% [markdown]
# ## 3) Multi-output MLP definition

# %%
class MLPRegressor(nn.Module):
    """Fully-connected regressor: [Linear -> ReLU -> (Dropout)] per hidden layer, then Linear.

    Args:
        in_dim: number of input features.
        out_dim: number of regression outputs.
        hidden: widths of the hidden layers, in order.
        dropout: dropout probability after each hidden ReLU. When it is 0 the
            Dropout layers are not inserted at all.

    All layers live in ``self.net`` (an ``nn.Sequential``), so state-dict keys
    look like ``net.<index>.weight`` / ``net.<index>.bias``.
    """

    def __init__(self, in_dim, out_dim, hidden=(128, 64), dropout=0.1):
        super().__init__()
        widths = [in_dim, *hidden]
        modules = []
        # Consecutive width pairs define each hidden Linear layer.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules += [nn.Linear(fan_in, fan_out), nn.ReLU()]
            if dropout > 0:
                modules.append(nn.Dropout(dropout))
        modules.append(nn.Linear(widths[-1], out_dim))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Map a batch ``x`` to ``out_dim`` outputs per row."""
        return self.net(x)

def batches(X_t, Y_t, batch_size=256):
    """Yield shuffled, row-aligned mini-batches ``(X_batch, Y_batch)``.

    A fresh permutation is drawn with torch's RNG on ``X_t``'s device on every
    call. The final batch may be smaller than ``batch_size``.
    """
    n_rows = X_t.shape[0]
    perm = torch.randperm(n_rows, device=X_t.device)
    for start in range(0, n_rows, batch_size):
        rows = perm[start:start + batch_size]
        yield X_t[rows], Y_t[rows]


# %% [markdown]
# ## 4) Train multi-output model

# %%
# Model dimensions come from the scaled training matrices.
in_dim = X_train_scaled.shape[1]
out_dim = Y_train_scaled.shape[1]

multi_model = MLPRegressor(in_dim, out_dim, hidden=(128, 64), dropout=0.1).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(multi_model.parameters(), lr=1e-3, weight_decay=1e-4)

# Train/validation splits moved to the device once, as float32 tensors.
Xtr_t, Ytr_t, Xva_t, Yva_t = (
    torch.tensor(arr, dtype=torch.float32, device=device)
    for arr in (X_train_scaled, Y_train_scaled, X_val_scaled, Y_val_scaled)
)

# Early-stopping schedule.
max_epochs = 200
patience = 20

def _train_multi_core():
    """Full-batch training loop for the multi-output model, with early stopping.

    Each epoch performs ONE Adam step on the entire training set and then
    evaluates the validation MSE. Whenever the validation loss improves, the
    weights are checkpointed to ``best_multi_model_cbc_v2.pt``. The best weights
    of this run are restored into ``multi_model`` before returning.

    Returns:
        dict with lists ``train_loss`` and ``val_loss``, one entry per epoch.
        ``train_loss`` is the loss computed before that epoch's optimizer step.
    """
    import copy

    history = {"train_loss": [], "val_loss": []}
    best_val = float("inf")
    best_state = None
    pat_counter = 0
    ckpt_path = OUT_DIR / "best_multi_model_cbc_v2.pt"

    for epoch in range(max_epochs):
        multi_model.train()
        optimizer.zero_grad()
        preds_tr = multi_model(Xtr_t)
        loss_tr = criterion(preds_tr, Ytr_t)
        loss_tr.backward()
        optimizer.step()

        multi_model.eval()
        with torch.no_grad():
            preds_va = multi_model(Xva_t)
            loss_va = criterion(preds_va, Yva_t).item()
            loss_tr_val = loss_tr.item()

        history["train_loss"].append(loss_tr_val)
        history["val_loss"].append(loss_va)

        print(
            f"Epoch {epoch+1:03d}/{max_epochs} | "
            f"train_loss={loss_tr_val:.6f}, val_loss={loss_va:.6f}"
        )

        if loss_va < best_val - 1e-8:
            best_val = loss_va
            pat_counter = 0
            # Deep copy: state_dict() hands out the live parameter tensors,
            # which later optimizer steps would overwrite in place.
            best_state = copy.deepcopy(multi_model.state_dict())
            torch.save(best_state, ckpt_path)
        else:
            pat_counter += 1
            if pat_counter >= patience:
                print("Early stopping at epoch", epoch + 1)
                break

    # Restore the best weights of THIS run from memory. Re-reading the file on
    # disk could pick up a stale checkpoint from a previous run when no epoch
    # improved (e.g. a NaN validation loss never compares as smaller).
    if best_state is not None:
        multi_model.load_state_dict(best_state)
    else:
        print("Warning: validation loss never improved; keeping last-epoch weights.")

    return history

# --- MEDICIÓN DE TIEMPO Y MEMORIA TOTAL (PEAK) EN MB ---
(history, train_time_multi,
 mem_peak_mb, mem_start_mb, mem_end_mb) = profile_stage_total_MB(_train_multi_core, interval=0.05)

print(f"\n[Multi-output] Training time: {train_time_multi:.3f} s")
print(f"[Multi-output] Training memory PEAK: {mem_peak_mb:.2f} MB (start={mem_start_mb:.2f}, end={mem_end_mb:.2f})")

add_runtime_record("train_multi", "multi-output", train_time_multi, mem_peak_mb, mem_start_mb, mem_end_mb)

# Per-epoch loss history to CSV.
pd.DataFrame(history).to_csv(TAB_DIR / "training_history_multi_cbc_v2.csv", index=False)

# Loss curves with fixed colors (blue = training, orange = validation).
plt.figure(figsize=(10, 5))
ax = plt.gca()
ax.set_facecolor("#f5f5f5")
for key, label, color in (
    ("train_loss", "Training", "#1f77b4"),
    ("val_loss", "Validation", "#ff7f0e"),
):
    plt.plot(history[key], label=label, color=color)
plt.legend(fontsize=14)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.xlim([0, len(history["train_loss"])])
plt.tight_layout()
save_current_fig(IMG_DIR / "training_curve_multi_cbc_v2")
plt.show()


# %% [markdown]
# ## 5) Evaluation of multi-output model

# %%
multi_model.eval()

def eval_split_multi(split_name, X_scaled, Y_true, Y_df):
    """Evaluate the multi-output model on one split, in original target units.

    Args:
        split_name: value stored in the ``split`` column (e.g. 'train').
        X_scaled: standardized feature matrix for the split.
        Y_true: unscaled target array for the split. NOTE(review): not used
            here; the ground truth is read from ``Y_df``.
        Y_df: DataFrame of unscaled targets with columns ``target_cols``.

    Returns:
        ``(metrics_df, preds)``: one metrics row per target (R2, RMSE, MAE,
        F1_median) and the de-standardized predictions. F1_median binarizes truth
        and prediction at the TRAIN median of each target.
    """
    X_in = torch.tensor(X_scaled, dtype=torch.float32, device=device)
    with torch.no_grad():
        preds_std = multi_model(X_in).cpu().numpy()
    # Undo the target StandardScaler.
    preds = preds_std * scaler_Y.scale_ + scaler_Y.mean_

    rows = []
    for col_idx, target in enumerate(target_cols):
        truth = Y_df[target].values
        estimate = preds[:, col_idx]
        cut = np.median(Y_train[:, col_idx])
        rows.append({
            "split": split_name,
            "target_raw": target,
            "target_display": target_display_map.get(target, target),
            "R2": r2_score(truth, estimate),
            "RMSE": rmse_compat(truth, estimate),
            "MAE": mean_absolute_error(truth, estimate),
            "F1_median": f1_score((truth > cut).astype(int), (estimate > cut).astype(int)),
        })

    return pd.DataFrame(rows), preds

# Unscaled ground-truth targets as DataFrames (columns = raw target names).
Y_train_df, Y_val_df, Y_test_df = (
    pd.DataFrame(arr, columns=target_cols) for arr in (Y_train, Y_val, Y_test)
)

# Metrics per split; the predictions are kept for the residual analysis.
m_tr, preds_tr = eval_split_multi("train", X_train_scaled, Y_train, Y_train_df)
m_val, preds_val = eval_split_multi("val", X_val_scaled, Y_val, Y_val_df)
m_te, preds_te = eval_split_multi("test", X_test_scaled, Y_test, Y_test_df)

metrics_multi = pd.concat([m_tr, m_val, m_te], ignore_index=True)
metrics_multi.to_csv(TAB_DIR / "metrics_multi_cbc_v2.csv", index=False)
print(metrics_multi)

# Correlation between per-target test residuals (truth - prediction).
residuals_test = Y_test - preds_te
res_corr = pd.DataFrame(residuals_test, columns=target_cols).corr()
print("\nResiduals correlation (test):")
print(res_corr)
res_corr.to_csv(TAB_DIR / "residuals_corr_multi_cbc_v2_test.csv", index_label="target_raw")

tick_positions = range(len(target_cols))
plt.figure()
plt.imshow(res_corr.values, cmap="coolwarm", vmin=-1, vmax=1)
plt.colorbar()
plt.xticks(tick_positions, target_display_names)
plt.yticks(tick_positions, target_display_names)
plt.title("Residuals correlation — MULTI-OUTPUT (test, v2)")
plt.tight_layout()
save_current_fig(IMG_DIR / "residuals_corr_multi_cbc_v2_test")
plt.show()


# %% [markdown]
# ## 6) SHAP — background & explained sets

# %%
# Background rows (reference distribution for DeepExplainer) come from TRAIN;
# the explained rows come from TEST. Draw order (background first) is fixed.
bg_size = min(1000, X_train_scaled.shape[0])
exp_size = min(2000, X_test_scaled.shape[0])

rng = np.random.default_rng(0)
bg_idx = rng.choice(X_train_scaled.shape[0], size=bg_size, replace=False)
exp_idx = rng.choice(X_test_scaled.shape[0], size=exp_size, replace=False)

X_bg_scaled, X_exp_scaled = X_train_scaled[bg_idx], X_test_scaled[exp_idx]

X_exp_np = X_exp_scaled.copy()
# Scaled explained rows with the raw column names (column order used by SHAP).
X_exp_view = pd.DataFrame(X_exp_np, columns=clinical_feature_names)

X_bg_scaled_t, X_exp_scaled_t = (
    torch.tensor(arr, dtype=torch.float32, device=device)
    for arr in (X_bg_scaled, X_exp_scaled)
)

# Explained rows back in original units (used for plot coloring/axes).
X_exp_unscaled = pd.DataFrame(
    scaler_X.inverse_transform(X_exp_np),
    columns=clinical_feature_names,
)

for label, value in (
    ("Background shape:", X_bg_scaled.shape),
    ("Explained shape:", X_exp_scaled.shape),
    ("Raw feature names order:", clinical_feature_names),
    ("Display feature names:", clinical_feature_names_display),
):
    print(label, value)


# %% [markdown]
# ## 7) Compute SHAP values for multi-output model (real feature display names)

# %%
multi_model.eval()

n_outputs, n_features = len(target_cols), len(clinical_feature_names)

def _compute_shap_multi_core():
    """Explain the held-out rows with a DeepExplainer built on the background set.

    Returns the raw ``shap_values`` output unchanged (either a per-output list
    or a stacked array); the caller normalizes the layout.
    """
    return shap.DeepExplainer(multi_model, X_bg_scaled_t).shap_values(X_exp_scaled_t)

# --- MEDICIÓN DE TIEMPO Y MEMORIA TOTAL (PEAK) EN MB PARA SHAP multi ---
(shap_values_raw, shap_time_multi,
 mem_peak_mb, mem_start_mb, mem_end_mb) = profile_stage_total_MB(_compute_shap_multi_core, interval=0.05)

print(f"\n[SHAP multi-output] Time: {shap_time_multi:.3f} s")
print(f"[SHAP multi-output] Memory PEAK: {mem_peak_mb:.2f} MB (start={mem_start_mb:.2f}, end={mem_end_mb:.2f})")

add_runtime_record("shap_multi", "multi-output", shap_time_multi, mem_peak_mb, mem_start_mb, mem_end_mb)

# Normalización y reestructuración de los SHAP values como antes
# Normalize the raw SHAP output into a list holding one
# (n_samples, n_features) array per target, in target_cols order.
shap_values_per_output = []

if isinstance(shap_values_raw, (list, tuple)):
    # List layout: one array per model output.
    if len(shap_values_raw) != n_outputs:
        raise ValueError(f"Expected {n_outputs} outputs in SHAP list, got {len(shap_values_raw)}")
    for k, per_output in enumerate(shap_values_raw):
        vals = np.squeeze(np.array(per_output))
        if vals.ndim == 1:
            vals = vals.reshape(-1, n_features)
        if vals.shape[1] != n_features:
            raise ValueError(
                f"Unexpected SHAP shape {vals.shape} for output {target_cols[k]} "
                f"(expected second dim = {n_features})"
            )
        shap_values_per_output.append(vals)
else:
    # Single stacked array: figure out which axis holds the outputs.
    stacked = np.array(shap_values_raw)
    if stacked.ndim != 3:
        raise ValueError(f"Unexpected SHAP array ndim={stacked.ndim}, expected 3.")
    if stacked.shape[0] == n_outputs and stacked.shape[2] == n_features:
        # (n_outputs, n_samples, n_features)
        shap_values_per_output = [stacked[k] for k in range(n_outputs)]
    elif stacked.shape[2] == n_outputs and stacked.shape[1] == n_features:
        # (n_samples, n_features, n_outputs)
        shap_values_per_output = [stacked[:, :, k] for k in range(n_outputs)]
    else:
        raise ValueError(
            f"Cannot interpret SHAP array shape {stacked.shape} "
            f"for n_outputs={n_outputs}, n_features={n_features}"
        )

print("Multi-output SHAP shapes per output:")
for out_name, vals in zip(target_cols, shap_values_per_output):
    print(out_name, vals.shape)

# One .npy file per target.
for out_name, vals in zip(target_cols, shap_values_per_output):
    np.save(OUT_DIR / f"shap_values_multi_{out_name}_cbc_v2.npy", vals)

with open(OUT_DIR / "shap_multi_meta_cbc_v2.json", "w") as f:
    json.dump(dict(
        mode="multi-output",
        method="DeepExplainer",
        background_size=int(X_bg_scaled.shape[0]),
        explain_size=int(X_exp_scaled.shape[0]),
        targets_raw=target_cols,
        targets_display=target_display_names,
        n_features_used=int(X_exp_np.shape[1]),
        feature_names_raw=clinical_feature_names,
        feature_names_display=clinical_feature_names_display,
    ), f, indent=2)

print("Saved multi-output SHAP arrays (.npy) and metadata (.json) for v2.")


# %% [markdown]
# ## 8) SHAP multi-output — bar, beeswarm, heatmap, dependence (display names)

# %%
n_outputs = len(target_cols)
n_features = len(clinical_feature_names_display)

# Mean |SHAP| per feature (index = display names), one Series per target,
# kept in the original feature order.
mean_abs = {
    out_name: pd.Series(
        np.abs(vals).mean(axis=0),
        index=clinical_feature_names_display,
        name=out_name,
    )
    for out_name, vals in zip(target_cols, shap_values_per_output)
}

# All targets side by side in one CSV.
mean_abs_df = pd.DataFrame(mean_abs)
mean_abs_df.to_csv(TAB_DIR / "mean_abs_shap_multi_cbc_v2.csv", index_label="feature_display")

# 2) Bar & beeswarm (Top-5) por salida
# 2) Per-target bar chart (top features by mean |SHAP|) and beeswarm.
bar_top_n = 10       # features shown in each bar chart
beeswarm_top_n = 6   # features shown in each beeswarm
for i, out_name in enumerate(target_cols):
    out_display = target_display_map.get(out_name, out_name)

    top_feats = mean_abs[out_name].sort_values(ascending=False).head(bar_top_n)

    # --- BAR ---
    plt.figure(figsize=(8, 5))
    ax = plt.gca()
    ax.set_facecolor("white")
    top_feats[::-1].plot(kind="barh", color="#1f77b4")  # largest bar at the top
    # The title reports the number of bars actually drawn (it may be fewer
    # than bar_top_n when there are few features).
    plt.title(f"Top-{len(top_feats)} mean |SHAP| — Multi-output — {out_display}")
    plt.xlabel("Mean |SHAP value|")
    plt.tight_layout()
    save_current_fig(IMG_DIR / f"shap_bar_multi_{out_name}_cbc_v2")
    plt.show()

    # --- BEESWARM --- (colored by unscaled feature values)
    plt.figure(figsize=(8, 6))
    shap.summary_plot(
        shap_values_per_output[i],
        features=X_exp_unscaled,
        feature_names=clinical_feature_names_display,
        max_display=beeswarm_top_n,
        show=False,
    )
    plt.title(f"Nonlinear Multi-output model — {out_display}")
    plt.tight_layout()
    save_current_fig(IMG_DIR / f"shap_beeswarm_multi_{out_name}_cbc_v2")
    plt.show()

print(f"Saved Top-{bar_top_n} bar & Top-{beeswarm_top_n} beeswarm plots (multi-output, v2) with display names.")

# 3) Heatmap: Top-10 features por importancia total (sumando salidas)
# 3) Heatmap: Top-10 features ranked by total mean |SHAP| over all outputs.
combined = pd.DataFrame(
    {name: mean_abs[name].values for name in target_cols},
    index=clinical_feature_names_display,
)

topK = 10
overall = combined.abs().sum(axis=1).sort_values(ascending=False)
top_features = overall.head(topK).index.tolist()

# Rows = outputs, columns = selected features.
heat_data = combined.loc[top_features].T

plt.figure(figsize=(10, 6))
plt.imshow(heat_data.values, aspect="auto", cmap="coolwarm")
plt.colorbar(label="Mean |SHAP value|")
plt.xticks(ticks=np.arange(len(top_features)), labels=top_features, rotation=45, ha="right")
plt.yticks(ticks=np.arange(len(target_cols)), labels=target_display_names)
plt.title("Top-10 features — Mean |SHAP| (multi-output, CBC v2)")
plt.tight_layout()
save_current_fig(IMG_DIR / "shap_heatmap_multi_top10_cbc_v2")
plt.show()

print("Saved heatmap for Top-10 features (multi-output, v2).")

# 4) Dependence plots: feature Top-1 por salida
# Display name -> column position; built once because it does not depend on the target.
name_to_idx = {disp: j for j, disp in enumerate(clinical_feature_names_display)}

# Top-1 feature (largest mean |SHAP|) for each target.
top1_idx = {}
for out_name in target_cols:
    feat_display = mean_abs[out_name].sort_values(ascending=False).index[0]
    top1_idx[out_name] = (name_to_idx[feat_display], feat_display)

print("Top-1 features (multi-output, v2):", {k: v[1] for k, v in top1_idx.items()})

for k, out_name in enumerate(target_cols):
    out_display = target_display_map.get(out_name, out_name)
    feat_idx, feat_display = top1_idx[out_name]

    # Create the figure/axes explicitly and hand them to SHAP. Calling
    # plt.figure() and then dependence_plot without ``ax`` leaves an extra,
    # empty figure behind (it is displayed by plt.show()).
    fig, ax = plt.subplots(figsize=(7.5, 5))
    shap.dependence_plot(
        ind=feat_idx,
        shap_values=shap_values_per_output[k],
        features=X_exp_unscaled,
        feature_names=clinical_feature_names_display,
        ax=ax,
        show=False,
    )
    ax.set_title(f"SHAP dependence — Multi-output — {out_display} vs {feat_display}")
    fig.tight_layout()
    save_current_fig(IMG_DIR / f"shap_dependence_multi_{out_name}_feat{feat_idx}_cbc_v2")
    plt.show()

print("Saved dependence plots (multi-output, v2) with display names.")


# %% [markdown]
# ## 9) Single-output models (training & evaluation)

# %%
single_models = {}     # raw target name -> trained single-output MLPRegressor
single_histories = {}  # raw target name -> {"train_loss": [...], "val_loss": [...]}

def train_single_output_model(k, out_name):
    """Train a single-output MLP for target column ``k`` (raw name ``out_name``).

    Uses the same architecture, optimizer settings and full-batch
    early-stopping schedule as the multi-output model. The wall time and TOTAL
    peak memory (MB) of the training loop are logged with
    ``add_runtime_record(stage="train_single", ...)``. The model (restored to
    its best validation weights) and its loss history are stored in
    ``single_models`` / ``single_histories``. The history CSV and the
    loss-curve figure are written to disk.
    """
    import copy

    print(f"\n=== Training single-output model for {out_name} ===")

    in_dim = X_train_scaled.shape[1]
    model_k = MLPRegressor(in_dim, 1, hidden=(128, 64), dropout=0.1).to(device)

    criterion_k = nn.MSELoss()
    optimizer_k = optim.Adam(model_k.parameters(), lr=1e-3, weight_decay=1e-4)

    # Column k of the scaled targets, shaped (n, 1) to match the model output.
    ytr_k = Y_train_scaled[:, k]
    yva_k = Y_val_scaled[:, k]

    Ytr_k_t = torch.tensor(ytr_k.reshape(-1, 1), dtype=torch.float32, device=device)
    Yva_k_t = torch.tensor(yva_k.reshape(-1, 1), dtype=torch.float32, device=device)

    Xtr_t_loc = torch.tensor(X_train_scaled, dtype=torch.float32, device=device)
    Xva_t_loc = torch.tensor(X_val_scaled,   dtype=torch.float32, device=device)

    max_epochs_k = 200
    patience_k   = 20
    ckpt_path = OUT_DIR / f"best_single_model_{out_name}_cbc_v2.pt"

    def _train_single_core():
        """Full-batch training with early stopping; restores this run's best weights."""
        history_k = {"train_loss": [], "val_loss": []}
        best_val = float("inf")
        best_state = None
        pat_counter = 0

        for epoch in range(max_epochs_k):
            model_k.train()
            optimizer_k.zero_grad()
            preds_tr_k = model_k(Xtr_t_loc)
            loss_tr_k = criterion_k(preds_tr_k, Ytr_k_t)
            loss_tr_k.backward()
            optimizer_k.step()

            model_k.eval()
            with torch.no_grad():
                preds_va_k = model_k(Xva_t_loc)
                loss_va_k = criterion_k(preds_va_k, Yva_k_t).item()
                loss_tr_k_val = loss_tr_k.item()

            history_k["train_loss"].append(loss_tr_k_val)
            history_k["val_loss"].append(loss_va_k)

            print(
                f"[{out_name}] Epoch {epoch+1:03d}/{max_epochs_k} | "
                f"train_loss={loss_tr_k_val:.6f}, val_loss={loss_va_k:.6f}"
            )

            if loss_va_k < best_val - 1e-8:
                best_val = loss_va_k
                pat_counter = 0
                # Deep copy: state_dict() hands out the live parameter tensors,
                # which later optimizer steps would overwrite in place.
                best_state = copy.deepcopy(model_k.state_dict())
                torch.save(best_state, ckpt_path)
            else:
                pat_counter += 1
                if pat_counter >= patience_k:
                    print(f"Early stopping (single {out_name}) at epoch", epoch + 1)
                    break

        # Restore the best weights of THIS run from memory. Re-reading the file
        # could load a stale checkpoint from a previous run when no epoch
        # improved (e.g. a NaN validation loss).
        if best_state is not None:
            model_k.load_state_dict(best_state)
        else:
            print(f"Warning ({out_name}): validation loss never improved; keeping last-epoch weights.")

        return history_k

    # --- Time and TOTAL peak memory (MB) of the training loop ---
    history_k, train_time_single, mem_peak_mb, mem_start_mb, mem_end_mb = profile_stage_total_MB(_train_single_core, interval=0.05)

    print(f"[Single {out_name}] Training time: {train_time_single:.3f} s")
    print(f"[Single {out_name}] Training memory PEAK: {mem_peak_mb:.2f} MB (start={mem_start_mb:.2f}, end={mem_end_mb:.2f})")

    add_runtime_record(
        stage="train_single",
        detail=out_name,
        time_sec=train_time_single,
        mem_peak_mb=mem_peak_mb,
        mem_start_mb=mem_start_mb,
        mem_end_mb=mem_end_mb,
    )

    single_models[out_name] = model_k
    single_histories[out_name] = history_k

    # Loss history CSV and training curve.
    hist_df = pd.DataFrame(history_k)
    hist_df.to_csv(TAB_DIR / f"training_history_single_{out_name}_cbc_v2.csv", index=False)

    plt.figure(figsize=(10, 5))
    ax = plt.gca()
    ax.set_facecolor("#f5f5f5")
    plt.plot(history_k["train_loss"], label="Training",  color="#1f77b4")
    plt.plot(history_k["val_loss"],   label="Validation", color="#ff7f0e")
    plt.legend(fontsize=14)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.xlim([0, len(history_k["train_loss"])])
    plt.tight_layout()
    save_current_fig(IMG_DIR / f"training_curve_single_{out_name}_cbc_v2")
    plt.show()


# Train one single-output model per target.
for k, out_name in enumerate(target_cols):
    train_single_output_model(k, out_name)


# %% [markdown]
# ## 10) SHAP for single-output models (with display feature names)

# %%
def normalize_single_output_shap(shap_raw, n_features):
    """Coerce a single-output DeepExplainer result to shape (n_samples, n_features).

    Accepts a list/tuple (its first element is used) or an array-like.
    Singleton dimensions, such as a trailing output axis of size 1, are
    squeezed away. A 0-D or 1-D result (one explained row) is reshaped back to 2-D.

    Raises:
        ValueError: on an empty list, on a result that is not 2-D after
            squeezing (e.g. several outputs), or when the feature dimension
            differs from ``n_features``.
    """
    if isinstance(shap_raw, (list, tuple)):
        if len(shap_raw) == 0:
            raise ValueError("Empty SHAP result for single-output model.")
        arr = np.array(shap_raw[0])
    else:
        arr = np.array(shap_raw)
    arr = np.squeeze(arr)
    # 0-D happens for 1 sample x 1 feature; 1-D for a single sample or a single feature.
    if arr.ndim <= 1:
        arr = arr.reshape(-1, n_features)
    if arr.ndim != 2 or arr.shape[1] != n_features:
        raise ValueError(f"Unexpected SHAP shape {arr.shape}, expected {n_features} features")
    return arr

n_features = len(clinical_feature_names)
single_shap_values = {}

for k, out_name in enumerate(target_cols):
    print(f"\n=== Computing SHAP for single-output model: {out_name} (v2) ===")
    model_k = single_models[out_name]
    model_k.eval()

    def _compute_shap_single_core():
        """Explain ``model_k`` on the held-out rows and return the normalized SHAP array."""
        raw = shap.DeepExplainer(model_k, X_bg_scaled_t).shap_values(X_exp_scaled_t)
        return normalize_single_output_shap(raw, n_features=n_features)

    # Time and TOTAL peak memory (MB) of building the explainer plus explaining.
    (shap_arr, shap_time_single,
     mem_peak_mb, mem_start_mb, mem_end_mb) = profile_stage_total_MB(_compute_shap_single_core, interval=0.05)

    print(f"[SHAP single {out_name}] Time: {shap_time_single:.3f} s")
    print(f"[SHAP single {out_name}] Memory PEAK: {mem_peak_mb:.2f} MB (start={mem_start_mb:.2f}, end={mem_end_mb:.2f})")

    add_runtime_record("shap_single", out_name, shap_time_single, mem_peak_mb, mem_start_mb, mem_end_mb)

    # Keep in memory for the plots and persist to disk.
    single_shap_values[out_name] = shap_arr
    np.save(OUT_DIR / f"shap_values_single_{out_name}_cbc_v2.npy", shap_arr)

with open(OUT_DIR / "shap_single_meta_cbc_v2.json", "w") as f:
    json.dump(dict(
        mode="single-output-per-target",
        method="DeepExplainer",
        background_size=int(X_bg_scaled_t.shape[0]),
        explain_size=int(X_exp_scaled_t.shape[0]),
        targets_raw=target_cols,
        targets_display=target_display_names,
        n_features_used=int(X_exp_np.shape[1]),
        feature_names_raw=clinical_feature_names,
        feature_names_display=clinical_feature_names_display,
    ), f, indent=2)

print("Saved single-output SHAP arrays (.npy) and metadata (.json) for v2.")


# %% [markdown]
# ## 11) SHAP single-output — mean |SHAP|, bar & beeswarm (display names)

# %%
# Mean |SHAP| per feature for each single-output model. Each Series (and its
# per-target CSV) is sorted by importance, descending.
mean_abs_single = {}
for out_name in target_cols:
    v = np.abs(single_shap_values[out_name]).mean(axis=0)
    s = pd.Series(v, index=clinical_feature_names_display).sort_values(ascending=False)
    mean_abs_single[out_name] = s
    s.to_csv(
        TAB_DIR / f"mean_abs_shap_single_{out_name}_cbc_v2.csv",
        header=["mean_abs_shap"],
        index_label="feature_display",
    )

# Combine by INDEX (feature name), not by position. Each Series above is sorted
# in its own order, so pairing `.values` with the unsorted feature list would
# attach every importance to the wrong feature. reindex() then restores the
# original feature order.
combined_single = pd.DataFrame(
    {out_name: mean_abs_single[out_name] for out_name in target_cols}
).reindex(clinical_feature_names_display)
combined_single.to_csv(
    TAB_DIR / "mean_abs_shap_single_all_cbc_v2.csv",
    index_label="feature_display",
)

print("Saved mean |SHAP| tables for single-output models (v2).")

bar_top_n_single = 10      # features shown in each bar chart
beeswarm_top_n_single = 5  # features shown in each beeswarm
for out_name in target_cols:
    out_display = target_display_map.get(out_name, out_name)
    top_feats = mean_abs_single[out_name].sort_values(ascending=False).head(bar_top_n_single)

    plt.figure(figsize=(8, 5))
    ax = plt.gca()
    ax.set_facecolor("white")
    top_feats[::-1].plot(kind="barh", color="#1f77b4")  # largest bar at the top
    # The title reports the number of bars actually drawn.
    plt.title(f"Top-{len(top_feats)} mean |SHAP| — Single-output — {out_display}")
    plt.xlabel("Mean |SHAP value|")
    plt.tight_layout()
    save_current_fig(IMG_DIR / f"shap_bar_single_{out_name}_cbc_v2")
    plt.show()

    plt.figure(figsize=(8, 6))
    shap.summary_plot(
        single_shap_values[out_name],
        X_exp_view,
        feature_names=clinical_feature_names_display,
        max_display=beeswarm_top_n_single,
        show=False,
    )
    plt.title(f"SHAP beeswarm — Single-output — {out_display}")
    plt.tight_layout()
    save_current_fig(IMG_DIR / f"shap_beeswarm_top5_single_{out_name}_cbc_v2")
    plt.show()

print(
    f"Saved Top-{bar_top_n_single} bar & Top-{beeswarm_top_n_single} beeswarm plots "
    "for single-output models (v2) with display names."
)


# %% [markdown]
# ## 12) SHAP single-output — heatmap, dependence, cross-output SHAP corr (display)

# %%
# Heatmap (single-output): Top-10 features by mean |SHAP| summed over targets.
topK_single = 10
overall_single = combined_single.abs().sum(axis=1).sort_values(ascending=False)
top_features_single = overall_single.head(topK_single).index.tolist()

# Rows = selected features, columns = targets.
heat_single = combined_single.loc[top_features_single, :].values

target_ticks = range(len(target_display_names))
feature_ticks = range(len(top_features_single))

plt.figure(figsize=(8, 6))
plt.imshow(heat_single, aspect="auto")
plt.colorbar()
plt.xticks(target_ticks, target_display_names, rotation=0)
plt.yticks(feature_ticks, top_features_single)
plt.title("Mean |SHAP| (single-output, v2) — Top features vs targets")
plt.tight_layout()
save_current_fig(IMG_DIR / "shap_heatmap_single_top_features_cbc_v2")
plt.show()

# Display name -> column position; built once because it does not depend on the target.
name_to_idx = {disp: j for j, disp in enumerate(clinical_feature_names_display)}

# Top-1 feature per target: idxmax is correct whether or not the Series is pre-sorted.
top1_single_idx = {}
for out_name in target_cols:
    feat_display = mean_abs_single[out_name].idxmax()
    top1_single_idx[out_name] = (name_to_idx[feat_display], feat_display)

print("Top-1 features (single-output, v2):", {k: v[1] for k, v in top1_single_idx.items()})

for out_name in target_cols:
    out_display = target_display_map.get(out_name, out_name)
    feat_idx, feat_display = top1_single_idx[out_name]

    # Explicit figure/axes handed to SHAP, so no extra empty figure is left
    # behind by a separate plt.figure() call.
    fig, ax = plt.subplots(figsize=(7.5, 5))
    shap.dependence_plot(
        ind=feat_idx,
        shap_values=single_shap_values[out_name],
        features=X_exp_unscaled,
        feature_names=clinical_feature_names_display,
        ax=ax,
        show=False,
    )
    ax.set_title(f"SHAP dependence — Single-output — {out_display} vs {feat_display}")
    fig.tight_layout()  # consistent with the multi-output dependence plots
    save_current_fig(IMG_DIR / f"shap_dependence_single_{out_name}_feat{feat_idx}_cbc_v2")
    plt.show()

from itertools import combinations

# For every pair of targets, correlate (Pearson, via pandas) the per-feature
# signed mean SHAP profiles of the two corresponding single-output models.
corr_rows_single = []
for outA, outB in combinations(target_cols, 2):
    mean_shap_A, mean_shap_B = (
        pd.Series(single_shap_values[o].mean(axis=0), index=clinical_feature_names_display)
        for o in (outA, outB)
    )
    corr_rows_single.append(
        {
            "output_A_raw": outA,
            "output_A_display": target_display_map.get(outA, outA),
            "output_B_raw": outB,
            "output_B_display": target_display_map.get(outB, outB),
            "featurewise_SHAP_corr": float(mean_shap_A.corr(mean_shap_B)),
        }
    )

corr_df_single = pd.DataFrame(corr_rows_single)
corr_df_single.to_csv(TAB_DIR / "cross_output_shap_single_featurewise_corr_cbc_v2.csv", index=False)
print(corr_df_single)
print("Saved cross-output SHAP correlation CSV for single-output models (v2) with display names.")


# %% [markdown]
# ## 13) Output correlations: TRUE vs multi-output vs single-output (display names)

# %%
print("=== Computing correlations among HGB, HCT, RBC (v2) ===")

# Ground-truth correlation between the targets on the test split.
true_corr = Y_test_df[target_cols].corr()

# Multi-output model: predict every target at once, then map back to original
# units with the target scaler's per-column scale_ and mean_.
multi_model.eval()
with torch.no_grad():
    preds_multi_std = multi_model(
        torch.tensor(X_test_scaled, dtype=torch.float32, device=device)
    ).cpu().numpy()
preds_multi = preds_multi_std * scaler_Y.scale_ + scaler_Y.mean_
preds_multi_df = pd.DataFrame(preds_multi, columns=target_cols)
multi_corr = preds_multi_df.corr()

# Single-output models: one model per target, un-scaled with that target's stats.
preds_single_dict = {}
for k, out_name in enumerate(target_cols):
    model_k = single_models[out_name]
    model_k.eval()
    with torch.no_grad():
        pred_std = model_k(
            torch.tensor(X_test_scaled, dtype=torch.float32, device=device)
        ).cpu().numpy().reshape(-1)
    preds_single_dict[out_name] = pred_std * scaler_Y.scale_[k] + scaler_Y.mean_[k]
preds_single_df = pd.DataFrame(preds_single_dict)
single_corr = preds_single_df.corr()

for corr_mat, corr_file in (
    (true_corr, "corr_true_outputs_cbc_v2_test.csv"),
    (multi_corr, "corr_multi_output_preds_cbc_v2_test.csv"),
    (single_corr, "corr_single_output_preds_cbc_v2_test.csv"),
):
    corr_mat.to_csv(TAB_DIR / corr_file, index_label="target_raw")

print("True correlation:\n", true_corr)
print("\nMulti-output preds correlation:\n", multi_corr)
print("\nSingle-output preds correlation:\n", single_corr)

def plot_corr_heatmap(mat, title, filename, labels=None):
    """
    Draw a correlation matrix as a heatmap, then save and show it.

    The colour scale is fixed to [-1, 1] ("coolwarm") so heatmaps from
    different calls are directly comparable.

    Parameters
    ----------
    mat : 2-D array-like
        Correlation coefficients (callers pass ``DataFrame.corr().values``).
    title : str
        Figure title.
    filename : str
        File stem (no extension); saved under IMG_DIR as PNG and PDF by
        ``save_current_fig``.
    labels : sequence of str, optional
        Tick labels for both axes. Defaults to ``target_display_names``.
    """
    if labels is None:
        labels = target_display_names
    ticks = range(len(labels))

    plt.figure(figsize=(5, 4))
    plt.imshow(mat, cmap="coolwarm", vmin=-1, vmax=1)
    plt.colorbar()
    plt.xticks(ticks, labels)
    plt.yticks(ticks, labels)
    plt.title(title)
    plt.tight_layout()
    save_current_fig(IMG_DIR / filename)
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close()

# One heatmap per correlation matrix: ground truth, multi-output and single-output predictions.
heatmap_jobs = [
    (true_corr, "Correlation — TRUE outputs (test, v2)", "corr_heatmap_true_outputs_cbc_v2_test"),
    (multi_corr, "Correlation — MULTI-OUTPUT predictions (test, v2)", "corr_heatmap_multi_output_preds_cbc_v2_test"),
    (single_corr, "Correlation — SINGLE-OUTPUT predictions (test, v2)", "corr_heatmap_single_output_preds_cbc_v2_test"),
]
for corr_frame, heat_title, heat_stem in heatmap_jobs:
    plot_corr_heatmap(corr_frame.values, heat_title, heat_stem)

print("\n=== Done! Correlation matrices and heatmaps saved successfully (v2). ===")


# %% [markdown]
# ## 14) Save runtime & memory summary (MB)

# %%
# Persist every timing/memory record collected during training and SHAP.
csv_path = TAB_DIR / "runtime_memory_cbc_v2_MB.csv"
runtime_df = pd.DataFrame(runtime_records)
runtime_df.to_csv(csv_path, index=False)

print("\n=== Runtime & Memory summary saved ===")
print("Path:", csv_path)
print(runtime_df)


# %% [markdown]
# ## 15) Evaluation metrics for Multi-output and Single-output Models (Train / Val / Test)

# %%
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error, f1_score
import numpy as np
import pandas as pd

# Targets used for the metrics below; presumably in original (unscaled) units,
# since they are compared against inverse-transformed predictions — TODO confirm.
Y_train_real, Y_val_real, Y_test_real = Y_train, Y_val, Y_test

# Inference mode for the multi-output model (affects layers such as dropout or
# batch-norm, if the architecture has them).
multi_model.eval()

def eval_multi_split(split_name, X_scaled, Y_real):
    """
    Evaluate the multi-output model on one data split, per target.

    Parameters
    ----------
    split_name : str
        Label stored in the ``split`` column (e.g. "train", "val", "test").
    X_scaled : array-like
        Scaled feature matrix; converted to a float32 tensor on ``device``.
    Y_real : array-like
        True targets in original units, one column per entry of ``target_cols``.
        Converted with ``np.asarray`` so a DataFrame works as well as an array.

    Returns
    -------
    list of dict
        One row per target with R2, RMSE, MAE and ``F1_median``: the F1 score
        after binarizing truth and prediction at the median of this split's
        true values.
    """
    Y_real = np.asarray(Y_real)

    # Switch to inference mode here (as eval_single_split does for its models)
    # rather than depending on a separate module-level .eval() call.
    multi_model.eval()
    with torch.no_grad():
        preds_scaled = multi_model(
            torch.tensor(X_scaled, dtype=torch.float32, device=device)
        ).cpu().numpy()

    # Map predictions back to the original target units.
    preds_real = scaler_Y.inverse_transform(preds_scaled)

    rows = []
    for j, out_name in enumerate(target_cols):
        out_display = target_display_map.get(out_name, out_name)

        y_true = Y_real[:, j]
        y_pred = preds_real[:, j]

        r2   = r2_score(y_true, y_pred)
        rmse = np.sqrt(mean_squared_error(y_true, y_pred))
        mae  = mean_absolute_error(y_true, y_pred)

        # Binary view: above / not above the median of the true values.
        median_thr = np.median(y_true)
        y_true_bin = (y_true > median_thr).astype(int)
        y_pred_bin = (y_pred > median_thr).astype(int)
        f1 = f1_score(y_true_bin, y_pred_bin)

        rows.append({
            "model_type": "multi-output",
            "split": split_name,
            "output": out_display,
            "R2": r2,
            "RMSE": rmse,
            "MAE": mae,
            "F1_median": f1,
        })
    return rows

# Evaluate the multi-output model on train, val and test (in that order).
metrics_multi_all = [
    row
    for split_name, X_split, Y_split in (
        ("train", X_train_scaled, Y_train_real),
        ("val",   X_val_scaled,   Y_val_real),
        ("test",  X_test_scaled,  Y_test_real),
    )
    for row in eval_multi_split(split_name, X_split, Y_split)
]

df_metrics_multi_all = pd.DataFrame(metrics_multi_all)
df_metrics_multi_all.to_csv(TAB_DIR / "metrics_multi_output_all_splits_cbc_v2.csv", index=False)

print("Multi-output metrics (train/val/test):")
display(df_metrics_multi_all)


def eval_single_split(split_name, X_scaled, Y_real):
    """
    Evaluate every single-output model on one data split.

    Parameters
    ----------
    split_name : str
        Label stored in the ``split`` column (e.g. "train", "val", "test").
    X_scaled : array-like
        Scaled feature matrix; converted once to a float32 tensor on ``device``.
    Y_real : array-like
        True targets in original units, one column per entry of ``target_cols``.
        Converted with ``np.asarray`` so a DataFrame works as well as an array.

    Returns
    -------
    list of dict
        One row per target with R2, RMSE, MAE and ``F1_median``: the F1 score
        after binarizing truth and prediction at the median of this split's
        true values.
    """
    Y_real = np.asarray(Y_real)
    # Same input for every per-target model, so build the tensor only once.
    X_tensor = torch.tensor(X_scaled, dtype=torch.float32, device=device)

    rows = []
    for k, out_name in enumerate(target_cols):
        out_display = target_display_map.get(out_name, out_name)

        model_k = single_models[out_name]
        model_k.eval()

        with torch.no_grad():
            # Flatten so both (n_samples,) and (n_samples, 1) outputs work.
            preds_k_scaled = model_k(X_tensor).cpu().numpy().reshape(-1)

        # Undo the target scaling for column k only, using that column's
        # scale_/mean_, instead of tiling the prediction across every target
        # column just to call inverse_transform and keep one column.
        preds_k_real = preds_k_scaled * scaler_Y.scale_[k] + scaler_Y.mean_[k]

        y_true_k = Y_real[:, k]

        r2_k   = r2_score(y_true_k, preds_k_real)
        rmse_k = np.sqrt(mean_squared_error(y_true_k, preds_k_real))
        mae_k  = mean_absolute_error(y_true_k, preds_k_real)

        # Binary view: above / not above the median of the true values.
        median_thr = np.median(y_true_k)
        y_true_bin = (y_true_k > median_thr).astype(int)
        y_pred_bin = (preds_k_real > median_thr).astype(int)
        f1_k = f1_score(y_true_bin, y_pred_bin)

        rows.append({
            "model_type": "single-output",
            "split": split_name,
            "output": out_display,
            "R2": r2_k,
            "RMSE": rmse_k,
            "MAE": mae_k,
            "F1_median": f1_k,
        })
    return rows

# Evaluate all single-output models on train, val and test (in that order).
metrics_single_all = []
for split_name, X_split, Y_split in (
    ("train", X_train_scaled, Y_train_real),
    ("val",   X_val_scaled,   Y_val_real),
    ("test",  X_test_scaled,  Y_test_real),
):
    metrics_single_all.extend(eval_single_split(split_name, X_split, Y_split))

df_metrics_single_all = pd.DataFrame(metrics_single_all)
df_metrics_single_all.to_csv(TAB_DIR / "metrics_single_output_all_splits_cbc_v2.csv", index=False)

print("Single-output metrics (train/val/test):")
display(df_metrics_single_all)


# %% [markdown]
# ## 16) Redesigned Figures (with custom font sizes)
#
# - (A) Bar comparison (Multi vs Single) with consistent colors
# - (B) Raw SHAP side-by-side beeswarms (Multi vs Single) — currently disabled (code kept inside a string literal)
# - Customizable font sizes for titles/axes/ticks/legend

# %%
# --- Safety checks: ensure required variables exist ---
# Names that earlier cells must have defined before the redesigned figures can run.
required_vars = [
    "target_cols",
    "target_display_map",
    "clinical_feature_names_display",
    "X_exp_unscaled",
    "shap_values_per_output",
    "single_shap_values",
    "IMG_DIR",
    "save_current_fig",
]
_defined_names = globals()
missing = [name for name in required_vars if name not in _defined_names]
if missing:
    raise RuntimeError(f"Missing variables needed for redesigned figures: {missing}")

# -----------------------------
# User-configurable parameters
# -----------------------------
# Number of features shown per output in the redesigned plots.
TOPK = 10

# Fixed colours so Multi vs Single read the same in every output's figure.
COLOR_MULTI, COLOR_SINGLE = "#1f77b4", "#ff7f0e"  # blue, orange

# Font sizes (edit here): title, axis labels, tick labels, legend.
FS_TITLE, FS_LABEL, FS_TICK, FS_LEG = 18, 16, 16, 13

# Apply them globally: affects every matplotlib figure created after this point.
plt.rcParams.update(
    {
        "axes.titlesize": FS_TITLE,
        "axes.labelsize": FS_LABEL,
        "xtick.labelsize": FS_TICK,
        "ytick.labelsize": FS_TICK,
        "legend.fontsize": FS_LEG,
    }
)

# Helper: mean(|SHAP|) as pandas.Series indexed by display feature names
def mean_abs_series(shap_array, feature_names_display):
    """Return mean(|SHAP|) over rows (samples) for each column, indexed by display name."""
    magnitudes = np.abs(shap_array)
    return pd.Series(magnitudes.mean(axis=0), index=feature_names_display)

# ============================================================
# Figure A: Side-by-side barplots (Multi vs Single) per output
# ============================================================
for out_idx, out_name in enumerate(target_cols):
    out_disp = target_display_map.get(out_name, out_name)  # only used by the disabled title

    # Per-feature mean(|SHAP|) for the multi-output head and the single-output model.
    s_multi = mean_abs_series(shap_values_per_output[out_idx], clinical_feature_names_display)
    s_single = mean_abs_series(single_shap_values[out_name], clinical_feature_names_display)

    # Shared Top-K ranked by summed importance, so both bar series use the same features.
    s_combined = (s_multi + s_single).sort_values(ascending=False)
    top_feats = s_combined.index[:TOPK].tolist()

    df_bar = pd.DataFrame(
        {
            "Multi-output": s_multi.loc[top_feats].values,
            "Single-output": s_single.loc[top_feats].values,
        },
        index=top_feats,
    )

    # Grouped horizontal bars: multi-output drawn above single-output for each feature.
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.set_facecolor("white")

    bar_h = 0.38
    ypos = np.arange(len(top_feats))
    for shift, series_name, colour in (
        (-bar_h / 2, "Multi-output", COLOR_MULTI),
        (bar_h / 2, "Single-output", COLOR_SINGLE),
    ):
        ax.barh(ypos + shift, df_bar[series_name], height=bar_h, color=colour, label=series_name)

    ax.set_yticks(ypos)
    ax.set_yticklabels(top_feats, fontsize=FS_TICK)
    ax.invert_yaxis()  # most important feature at the top

    ax.set_xlabel("Mean |SHAP value|", fontsize=FS_LABEL)
    # Title disabled; to re-enable:
    # ax.set_title(f"Top-{TOPK} mean |SHAP| comparison — {out_disp}", fontsize=FS_TITLE)

    ax.tick_params(axis="both", which="major", labelsize=FS_TICK)
    ax.legend(loc="lower right", fontsize=FS_LEG)

    # If long feature names get clipped, replace tight_layout with e.g.
    # fig.subplots_adjust(left=0.30, right=0.98, top=0.90, bottom=0.12)
    fig.tight_layout()

    save_current_fig(IMG_DIR / f"redesign_bar_compare_multi_vs_single_{out_idx}_top{TOPK}_cbc_v2")
    plt.show()
    plt.close(fig)

print(f"Saved redesigned BAR comparison plots (Top-{TOPK}) for all outputs.")

# ============================================================
# Figure B: Raw SHAP side-by-side beeswarms (Multi vs Single)
# ============================================================
# NOTE: this figure is currently DISABLED. The code below sits inside a bare
# triple-quoted string literal, which Python evaluates and discards, so none of
# it runs (and the closing print is commented out). To re-enable, remove the
# surrounding triple quotes and uncomment the print.
# NOTE(review): shap.summary_plot draws on / resizes the current figure; verify
# that it composes correctly with the two-panel gridspec before re-enabling.
"""
for out_idx, out_name in enumerate(target_cols):
    out_disp = target_display_map.get(out_name, out_name)

    sv_multi  = shap_values_per_output[out_idx]     # (n_samples, n_features)
    sv_single = single_shap_values[out_name]        # (n_samples, n_features)

    s_multi  = mean_abs_series(sv_multi,  clinical_feature_names_display)
    s_single = mean_abs_series(sv_single, clinical_feature_names_display)

    # Same TopK subset for both (based on combined importance)
    top_feats = (s_multi + s_single).sort_values(ascending=False).head(TOPK).index.tolist()
    feat_to_idx = {name: j for j, name in enumerate(clinical_feature_names_display)}
    top_idx = [feat_to_idx[f] for f in top_feats]

    sv_multi_sub  = sv_multi[:,  top_idx]
    sv_single_sub = sv_single[:, top_idx]

    # Use the same feature values subset
    X_sub = X_exp_unscaled.values[:, top_idx]

    fig = plt.figure(figsize=(14, 6))
    gs = fig.add_gridspec(1, 2, wspace=0.30)

    # Left: Multi-output raw SHAP
    ax1 = fig.add_subplot(gs[0, 0])
    plt.sca(ax1)
    shap.summary_plot(
        sv_multi_sub,
        features=X_sub,
        feature_names=top_feats,
        max_display=TOPK,
        show=False,
    )
    ax1 = plt.gca()
    ax1.set_title(f"Raw SHAP (beeswarm) — Multi-output — {out_disp}", fontsize=FS_TITLE)
    ax1.tick_params(axis="both", which="major", labelsize=FS_TICK)
    ax1.set_xlabel(ax1.get_xlabel(), fontsize=FS_LABEL)
    ax1.set_ylabel(ax1.get_ylabel(), fontsize=FS_LABEL)

    # Right: Single-output raw SHAP
    ax2 = fig.add_subplot(gs[0, 1])
    plt.sca(ax2)
    shap.summary_plot(
        sv_single_sub,
        features=X_sub,
        feature_names=top_feats,
        max_display=TOPK,
        show=False,
    )
    ax2 = plt.gca()
    ax2.set_title(f"Raw SHAP (beeswarm) — Single-output — {out_disp}", fontsize=FS_TITLE)
    ax2.tick_params(axis="both", which="major", labelsize=FS_TICK)
    ax2.set_xlabel(ax2.get_xlabel(), fontsize=FS_LABEL)
    ax2.set_ylabel(ax2.get_ylabel(), fontsize=FS_LABEL)

    plt.tight_layout()
    # If labels get clipped, increase margins:
    # plt.subplots_adjust(left=0.25, right=0.98, top=0.90, bottom=0.15, wspace=0.35)

    save_current_fig(IMG_DIR / f"redesign_beeswarm_compare_multi_vs_single_{out_name}_top{TOPK}_cbc_v2")
    plt.show()
"""
#print(f"Saved redesigned SIDE-BY-SIDE beeswarm (raw SHAP) plots (Top-{TOPK}) for all outputs.")


# %% [markdown]
# ## 17) Quantitative similarity of raw SHAP values
#
# - Cosine similarity (per output)
# - Spearman rank correlation (per output)
# - Computed on RAW SHAP values, flattened over all instances and features
# - Results saved as CSV for manuscript reporting

# %%
from scipy.stats import spearmanr

def cosine_similarity(vec_a, vec_b, eps=1e-12):
    """
    Compute the cosine similarity between two vectors.

    Both inputs are flattened and promoted to float64 before the dot product,
    so long float32 SHAP vectors do not accumulate single-precision rounding
    error, and 2-D inputs are compared element-wise instead of being
    matrix-multiplied. ``eps`` is added to the denominator so that an
    all-zero vector gives 0.0 rather than a division by zero.

    Parameters
    ----------
    vec_a, vec_b : array-like
        Inputs with the same total number of elements.
    eps : float
        Small constant added to the product of the norms.

    Returns
    -------
    float
        Cosine similarity, approximately in [-1, 1].
    """
    a = np.asarray(vec_a, dtype=np.float64).ravel()
    b = np.asarray(vec_b, dtype=np.float64).ravel()
    num = np.dot(a, b)
    den = np.linalg.norm(a) * np.linalg.norm(b) + eps
    return float(num / den)

similarity_rows = []

for k, out_name in enumerate(target_cols):
    # Raw (signed) SHAP matrices for target k: output k of the multi-output
    # model, and the dedicated single-output model.
    shap_multi = shap_values_per_output[k]
    shap_single = single_shap_values[out_name]

    if shap_multi.shape != shap_single.shape:
        raise ValueError(
            f"Shape mismatch for output {out_name}: "
            f"multi {shap_multi.shape}, single {shap_single.shape}"
        )

    # Compare every (instance, feature) entry by flattening both matrices.
    vec_multi = shap_multi.ravel()
    vec_single = shap_single.ravel()
    spearman_corr, spearman_p = spearmanr(vec_multi, vec_single)

    similarity_rows.append(
        {
            "output_raw": out_name,
            "output_display": target_display_map.get(out_name, out_name),
            "n_samples": shap_multi.shape[0],
            "n_features": shap_multi.shape[1],
            "cosine_similarity": cosine_similarity(vec_multi, vec_single),
            "spearman_correlation": float(spearman_corr),
            "spearman_pvalue": float(spearman_p),
        }
    )

# Per-output similarity table.
similarity_df = pd.DataFrame(similarity_rows)
csv_path = TAB_DIR / "shap_local_similarity_multi_vs_single_cbc_v2.csv"
similarity_df.to_csv(csv_path, index=False)

print("\n=== Local SHAP similarity (multi vs single) — per output ===")
print(similarity_df)
print(f"\nSaved SHAP similarity metrics to:\n{csv_path}")

# Compact summary: mean and sample std (ddof=1) of each metric across outputs.
similarity_metric_cols = ["cosine_similarity", "spearman_correlation"]
summary_df = pd.DataFrame(
    {
        "metric": similarity_metric_cols,
        "mean": [similarity_df[col].mean() for col in similarity_metric_cols],
        "std": [similarity_df[col].std(ddof=1) for col in similarity_metric_cols],
    }
)

summary_path = TAB_DIR / "shap_local_similarity_summary_multi_vs_single_cbc_v2.csv"
summary_df.to_csv(summary_path, index=False)

print("\n=== Summary across outputs (mean ± std) ===")
print(summary_df)
print(f"\nSaved summary to:\n{summary_path}")


# %% [markdown]
# ## 18) Local SHAP similarity restricted to the Top-K features (per output)
#
# - Top-K features (`topK`, currently 10) selected per output based on combined mean |SHAP|
#   (multi-output + single-output)
# - Cosine similarity and Spearman correlation computed on RAW SHAP values
# - Metrics computed separately for each output
# - Results saved as CSV

# %%
from scipy.stats import spearmanr

def cosine_similarity(vec_a, vec_b, eps=1e-12):
    """
    Cosine similarity of two vectors, computed in float64 on flattened inputs.

    Promoting to float64 avoids single-precision accumulation error on long
    float32 SHAP vectors; flattening lets 2-D inputs be compared element-wise.
    ``eps`` in the denominator makes an all-zero input return 0.0.
    """
    a = np.asarray(vec_a, dtype=np.float64).ravel()
    b = np.asarray(vec_b, dtype=np.float64).ravel()
    num = np.dot(a, b)
    den = np.linalg.norm(a) * np.linalg.norm(b) + eps
    return float(num / den)

# Number of features kept per output for the restricted similarity.
topK = 10
similarity_topk_rows = []

# Display name -> column index (not read by the loop below).
feat_to_idx = {name: j for j, name in enumerate(clinical_feature_names_display)}

for k, out_name in enumerate(target_cols):
    shap_multi = shap_values_per_output[k]     # multi-output model, output k
    shap_single = single_shap_values[out_name]  # dedicated single-output model

    # Rank features by the summed mean |SHAP| of both models so the same Top-K
    # subset is compared for multi- and single-output.
    combined_importance = np.abs(shap_multi).mean(axis=0) + np.abs(shap_single).mean(axis=0)
    ascending_order = np.argsort(combined_importance)
    topk_idx = ascending_order[::-1][:topK]
    topk_features_display = [clinical_feature_names_display[j] for j in topk_idx]

    # Restrict both raw SHAP matrices to the Top-K columns, then flatten.
    shap_multi_topk = shap_multi[:, topk_idx]
    shap_single_topk = shap_single[:, topk_idx]
    vec_multi = shap_multi_topk.ravel()
    vec_single = shap_single_topk.ravel()

    spearman_corr, spearman_p = spearmanr(vec_multi, vec_single)

    similarity_topk_rows.append(
        {
            "output_raw": out_name,
            "output_display": target_display_map.get(out_name, out_name),
            "topK": topK,
            "topK_features_display": "; ".join(topk_features_display),
            "n_samples": shap_multi_topk.shape[0],
            "cosine_similarity_topK": cosine_similarity(vec_multi, vec_single),
            "spearman_correlation_topK": float(spearman_corr),
            "spearman_pvalue_topK": float(spearman_p),
        }
    )

# Build DataFrame
similarity_topk_df = pd.DataFrame(similarity_topk_rows)

# File names and messages are derived from topK so they always state the number
# of features actually used (hard-coded "top5"/"Top-5" mislabeled topK=10 runs).
csv_path = TAB_DIR / f"shap_local_similarity_top{topK}_multi_vs_single_cbc_v2.csv"
similarity_topk_df.to_csv(csv_path, index=False)

print(f"\n=== Local SHAP similarity (Top-{topK} features, multi vs single) ===")
print(similarity_topk_df)
print(f"\nSaved Top-{topK} SHAP similarity metrics to:\n{csv_path}")

# --- Optional: summary across outputs (mean and sample std, ddof=1) ---
summary_topk_df = pd.DataFrame({
    "metric": ["cosine_similarity_topK", "spearman_correlation_topK"],
    "mean": [
        similarity_topk_df["cosine_similarity_topK"].mean(),
        similarity_topk_df["spearman_correlation_topK"].mean(),
    ],
    "std": [
        similarity_topk_df["cosine_similarity_topK"].std(ddof=1),
        similarity_topk_df["spearman_correlation_topK"].std(ddof=1),
    ],
})

summary_path = TAB_DIR / f"shap_local_similarity_top{topK}_summary_multi_vs_single_cbc_v2.csv"
summary_topk_df.to_csv(summary_path, index=False)

print(f"\n=== Top-{topK} summary across outputs (mean ± std) ===")
print(summary_topk_df)
print(f"\nSaved Top-{topK} summary to:\n{summary_path}")


In [None]:
# %% [markdown]
# ## 19) Combined figure: Top-K (`TOPK`; 10 when set in section 16, otherwise 5) mean |SHAP| (Multi vs Single) for all outputs

# %%
print("\n[FIG19] Starting combined TopK mean|SHAP| figure...")

# --- Quick diagnostics: report whether each input exists, never raise ---
diag_vars = [
    "target_cols",
    "target_display_map",
    "clinical_feature_names_display",
    "shap_values_per_output",
    "single_shap_values",
    "IMG_DIR",
    "save_current_fig",
    "TOPK",
]
for var_name in diag_vars:
    print(f"[FIG19] {var_name} in globals? ->", var_name in globals())

# Fallbacks for settings that an earlier cell may not have defined.
TOPK = globals().get("TOPK", 5)
COLOR_MULTI, COLOR_SINGLE = (
    globals().get("COLOR_MULTI", "#1f77b4"),
    globals().get("COLOR_SINGLE", "#ff7f0e"),
)
FS_TITLE, FS_LABEL, FS_TICK, FS_LEG = (
    globals().get(fs_name, fs_default)
    for fs_name, fs_default in (("FS_TITLE", 16), ("FS_LABEL", 14), ("FS_TICK", 12), ("FS_LEG", 12))
)

print(f"[FIG19] Using TOPK={TOPK}")

# --- Define helper (local) ---
def _mean_abs_series(shap_array, feature_names_display):
    import numpy as np
    import pandas as pd
    return pd.Series(np.mean(np.abs(shap_array), axis=0), index=feature_names_display)

# --- Assert minimally required data (fail with clear message) ---
# --- Validate minimally required data (fail with a clear message) ---
# Explicit checks rather than `assert`: assertions are stripped when Python runs
# with -O, which would let the figure code fail later with a confusing error.
if "target_cols" not in globals() or len(target_cols) == 0:
    raise RuntimeError("[FIG19] target_cols missing/empty.")
for required_name in (
    "clinical_feature_names_display",
    "shap_values_per_output",
    "single_shap_values",
    "IMG_DIR",
    "save_current_fig",
):
    if required_name not in globals():
        raise RuntimeError(f"[FIG19] {required_name} missing.")

print("[FIG19] Basic inputs OK.")
print("[FIG19] n_outputs =", len(target_cols))
print("[FIG19] n_features_display =", len(clinical_feature_names_display))
print("[FIG19] shap_values_per_output lens =", len(shap_values_per_output))
print("[FIG19] single_shap_values keys =", list(single_shap_values.keys()))

# --- Build the combined figure ---
import numpy as np
import matplotlib.pyplot as plt

# One panel per output. squeeze=False always returns a 2-D array of Axes, so a
# single output needs no special-casing.
fig, axes = plt.subplots(1, len(target_cols), figsize=(20, 6), squeeze=False)
axes = axes[0]

bar_width = 0.38

for out_idx, out_name in enumerate(target_cols):
    ax = axes[out_idx]
    out_disp = target_display_map.get(out_name, out_name) if "target_display_map" in globals() else out_name

    # Mean(|SHAP|) for multi and single
    s_multi  = _mean_abs_series(shap_values_per_output[out_idx], clinical_feature_names_display)
    s_single = _mean_abs_series(single_shap_values[out_name],     clinical_feature_names_display)

    # Common TopK based on combined importance, so both bar series share features
    s_combined = (s_multi + s_single).sort_values(ascending=False)
    top_feats = s_combined.head(TOPK).index.tolist()

    y_multi  = s_multi.loc[top_feats].values
    y_single = s_single.loc[top_feats].values

    x = np.arange(len(top_feats))

    ax.bar(x - bar_width/2, y_multi,  width=bar_width, color=COLOR_MULTI,  label="Multi-output")
    ax.bar(x + bar_width/2, y_single, width=bar_width, color=COLOR_SINGLE, label="Single-output")

    ax.set_xticks(x)
    ax.set_xticklabels(top_feats, rotation=90, ha="center", fontsize=FS_TICK)

    if out_idx == 0:
        ax.set_ylabel("Mean |SHAP value|", fontsize=FS_LABEL)

    ax.set_title(out_disp, fontsize=FS_TITLE)
    ax.tick_params(axis="y", labelsize=FS_TICK)
    ax.grid(axis="y", alpha=0.25)

# Lay out the panels (including the 90° tick labels) while reserving a strip at
# the bottom for the shared legend. constrained_layout + subplots_adjust does not
# work for this: matplotlib skips subplots_adjust while a layout engine is
# active, and constrained_layout does not make room for a figure legend at
# "lower center", so the legend could overlap the rotated labels.
fig.tight_layout(rect=(0, 0.1, 1, 1))

# One legend for the whole figure, placed in the reserved bottom strip
handles, labels = axes[0].get_legend_handles_labels()
fig.legend(handles, labels, loc="lower center", ncol=2, fontsize=FS_LEG, frameon=True)

out_path = IMG_DIR / f"combined_top{TOPK}_bar_mean_abs_shap_multi_vs_single_all_outputs_cbc_v2"
save_current_fig(out_path)

print(f"[FIG19] Saved figure to: {out_path.with_suffix('.png')} and {out_path.with_suffix('.pdf')}")

plt.show()
plt.close(fig)

print("[FIG19] Done.")
