# In [None]:
# %% [markdown]
# # CBC Multi-Output (v2) — LINEAR (Ridge) — HGB, HCT, RBC
#
# - Usa el dataset: cbc_synthetic_30000_enriched_v2.csv
# - Elimina features fuertemente correlacionadas (|corr| > 0.80) antes de entrenar
# - Entrena modelo multi-output y modelos single-output (lineales)
# - Calcula SHAP (LinearExplainer + masker)
# - Usa nombres "bonitos" para features y outputs en tablas y gráficos
# - Mide tiempo y memoria de entrenamiento y SHAP (multi vs single-output)
#   - Memoria medida como TOTAL PEAK en MB (decimal), NO delta.

# %%
import os, json, time
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import shap
from memory_profiler import memory_usage  # devuelve MiB

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import Ridge  # modelo lineal estable (multioutput nativo)

from sklearn.metrics import (
    r2_score,
    mean_absolute_error,
    mean_squared_error,
    f1_score
)

# -----------------------------
# Directorios base
# -----------------------------
# All artefacts are written below OUT_DIR, relative to the working directory.
BASE_DIR = Path(".")
OUT_DIR = BASE_DIR.joinpath("cbc_multi_output_v2_outputs_time_memory_MB_05_linear_ridge")
IMG_DIR = OUT_DIR.joinpath("figs")
TAB_DIR = OUT_DIR.joinpath("tables")

# Create the directory tree up front; already-existing directories are fine.
for d in (OUT_DIR, IMG_DIR, TAB_DIR):
    d.mkdir(parents=True, exist_ok=True)

def _safe_stem(s: str) -> str:
    """
    Return ``s`` with characters that are problematic in Windows file names
    (angle brackets, colon, double quote, both slashes, pipe, question mark,
    asterisk) each replaced by an underscore. Other characters are kept.
    """
    return s.translate({ord(ch): "_" for ch in '<>:"/\\|?*'})

def save_current_fig(path_no_ext: Path):
    """
    Save the current matplotlib figure twice: as ``.png`` and as ``.pdf``.

    ``path_no_ext`` is the destination path; each extension is applied with
    ``Path.with_suffix``. The output format is passed explicitly to
    ``savefig`` (the original author notes that PDFs were sometimes corrupt
    on Windows otherwise). The figure is NOT closed here; callers close it.
    """
    base = Path(path_no_ext)
    for ext in ("png", "pdf"):
        plt.savefig(base.with_suffix(f".{ext}"), bbox_inches="tight", dpi=300, format=ext)

def rmse_compat(y_true, y_pred):
    """
    Return the root-mean-squared error between ``y_true`` and ``y_pred``.

    Both arguments may be any array-like (list, tuple, NumPy array, pandas
    Series); they are converted with ``np.asarray`` and compared
    element by element, by position.

    Raises:
        ValueError: if the two inputs do not have the same shape. Without
            this check NumPy would broadcast e.g. ``(n,)`` against ``(n, 1)``
            into an ``(n, n)`` matrix and return a meaningless number.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"Shape mismatch: y_true {y_true.shape} vs y_pred {y_pred.shape}"
        )
    return np.sqrt(np.mean((y_true - y_pred) ** 2))

# ============================
# Registro runtime/memory (MB decimal)
# ============================
# One dict per profiled stage is accumulated here (written to CSV later in the script).
runtime_records = []

def add_runtime_record(stage: str, detail: str, time_sec: float, mem_peak_mb: float, mem_start_mb: float, mem_end_mb: float):
    """
    Append one time/memory measurement to the module-level ``runtime_records``.

    ``stage`` and ``detail`` are stored as given; the four numeric arguments
    are coerced to ``float`` and stored under ``time_seconds``,
    ``memory_peak_MB``, ``memory_start_MB`` and ``memory_end_MB``.
    """
    record = dict(
        stage=stage,
        detail=detail,
        time_seconds=float(time_sec),
        memory_peak_MB=float(mem_peak_mb),
        memory_start_MB=float(mem_start_mb),
        memory_end_MB=float(mem_end_mb),
    )
    runtime_records.append(record)

def profile_stage_total_MB(fn, interval: float = 0.05):
    """
    Run ``fn()`` once and measure its wall-clock time and process memory.

    Memory is sampled by ``memory_profiler.memory_usage`` every ``interval``
    seconds; the samples are total process memory in MiB and are converted
    here to decimal MB.

    The elapsed time is measured around ``fn`` itself, inside the profiled
    call, so it excludes the set-up and tear-down of the memory sampler.
    For stages that finish in milliseconds that overhead would otherwise
    dominate the reported time.

    Args:
        fn: zero-argument callable to profile.
        interval: sampling period in seconds passed to ``memory_usage``.

    Returns:
        Tuple ``(out, elapsed_seconds, mem_peak_mb, mem_start_mb, mem_end_mb)``
        where ``out`` is whatever ``fn()`` returned and start/end/peak are the
        first, last and maximum memory samples.
    """
    elapsed = {}

    def _timed_call():
        # Time only fn(); `finally` records the duration even if fn raises.
        t0 = time.perf_counter()
        try:
            return fn()
        finally:
            elapsed["seconds"] = time.perf_counter() - t0

    mem_trace_mib, out = memory_usage((_timed_call, (), {}), retval=True, interval=interval)

    mib_to_mb = 1.048576  # 1 MiB = 1,048,576 bytes = 1.048576 decimal MB
    mem_start_mb = float(mem_trace_mib[0] * mib_to_mb)
    mem_end_mb   = float(mem_trace_mib[-1] * mib_to_mb)
    mem_peak_mb  = float(max(mem_trace_mib) * mib_to_mb)

    return out, elapsed["seconds"], mem_peak_mb, mem_start_mb, mem_end_mb

print("OK: imports done.")


# %% [markdown]
# ## 1) Load dataset, drop highly correlated features, define display names

# %%
csv_path = BASE_DIR.joinpath("cbc_synthetic_30000_enriched_v2.csv")
df = pd.read_csv(csv_path)

print("Shape (raw):", df.shape)
print("Columns:", df.columns.tolist())

# Raw target column -> display label. The insertion order of this dict
# defines the order of the targets everywhere below.
target_display_map = {
    "y_hgb_gdl": "HGB (g/dL)",
    "y_hct_pct": "HCT (%)",
    "y_rbc_10^12_per_L": "RBC (10^12/L)",
}
target_cols = list(target_display_map)
target_display_names = list(target_display_map.values())

# Every column that is not a target is treated as a raw feature.
feature_cols_raw = [col for col in df.columns if col not in target_display_map]

# Raw feature column name -> display label used in tables and plots.
# Columns missing from this map fall back to their raw name where the map
# is consumed with `.get(c, c)`.
feature_display_map = {
    "age_years": "Age (years)",
    "sex_male": "Sex (male=1)",
    "bmi": "BMI",
    "iron_ugdl": "Iron (µg/dL)",
    "ferritin_ngml": "Ferritin (ng/mL)",
    "vitamin_d_ngml": "Vitamin D (ng/mL)",
    "folate_ngml": "Folate (ng/mL)",
    "vitamin_b12_pgml": "Vitamin B12 (pg/mL)",
    "crp_mgL": "CRP (mg/L)",
    "albumin_gdl": "Albumin (g/dL)",
    "creatinine_mgdl": "Creatinine (mg/dL)",
    "egfr_ml_min": "eGFR (mL/min)",
    "sbp_mmHg": "Systolic BP (mmHg)",
    "dbp_mmHg": "Diastolic BP (mmHg)",
    "heart_rate_bpm": "Heart rate (bpm)",
    "wbc_10^9_per_L": "WBC (10^9/L)",
    "plt_10^9_per_L": "Platelets (10^9/L)",
    "smoker": "Smoker (1 = yes)",
    "alcohol_units_per_week": "Alcohol (units/week)",
    "physical_activity_level": "Physical activity level",
}

# Feature-feature correlation; for every pair with |corr| > threshold the
# second column of the pair (the later one in column order) is dropped.
threshold = 0.80
cols = feature_cols_raw
corr_features = df[feature_cols_raw].corr()
corr_abs = corr_features.abs()

# Upper-triangle scan done once; reused for both the drop set and the report.
high_corr_pairs = [
    (i, j)
    for i in range(len(cols))
    for j in range(i + 1, len(cols))
    if corr_abs.iloc[i, j] > threshold
]
to_drop = {cols[j] for _, j in high_corr_pairs}

print("\nHighly correlated feature pairs (|corr| > 0.80):")
for i, j in high_corr_pairs:
    print(f"{cols[i]}  <->  {cols[j]} : corr = {corr_features.iloc[i, j]:.3f}")

print("\nFeatures to drop due to high correlation:", to_drop)

feature_cols = [name for name in feature_cols_raw if name not in to_drop]

print("\nNumber of features before:", len(feature_cols_raw))
print("Number of features after drop:", len(feature_cols))
print("Final feature list (kept):")
print(feature_cols)

display_feature_names = [feature_display_map.get(name, name) for name in feature_cols]

# Persist the selection so the run can be audited later.
dropped_report = {
    "threshold": threshold,
    "dropped_features": sorted(to_drop),
    "final_features_raw": feature_cols,
    "final_features_display": display_feature_names,
}
with open(OUT_DIR / "dropped_features_due_corr.json", "w") as f:
    json.dump(dropped_report, f, indent=2)

# Summary statistics of the three targets over the full dataset.
print("\nTargets description:")
print(df[target_cols].describe())

# Correlation between the true target columns (full dataset, before splitting).
corr_outputs = df[target_cols].corr()
print("\nCorrelation among outputs:")
print(corr_outputs)

corr_outputs.to_csv(TAB_DIR / "corr_outputs_true.csv", index_label="target_raw")

# Heatmap of the target correlation matrix; colour scale fixed to [-1, 1].
plt.figure()
plt.imshow(corr_outputs.values, cmap="coolwarm", vmin=-1, vmax=1)
plt.colorbar()
plt.xticks(range(len(target_cols)), target_display_names)
plt.yticks(range(len(target_cols)), target_display_names)
plt.title("Correlation — TRUE outputs (CBC v2)")
plt.tight_layout()
# Save before show(), then close the figure explicitly.
save_current_fig(IMG_DIR / "corr_heatmap_true_outputs_cbc_v2")
plt.show()
plt.close()


# %% [markdown]
# ## 2) Train/Val/Test split and scaling

# %%
# Feature and target matrices as float32 arrays.
X = df[feature_cols].to_numpy().astype(np.float32)
Y = df[target_cols].to_numpy().astype(np.float32)

print("X shape:", X.shape, "Y shape:", Y.shape)

# 80 % train, then the remaining 20 % is halved into validation and test.
X_train, X_temp, Y_train, Y_temp = train_test_split(X, Y, test_size=0.2, random_state=42)
X_val, X_test, Y_val, Y_test = train_test_split(X_temp, Y_temp, test_size=0.5, random_state=42)

print("Train:", X_train.shape, "Val:", X_val.shape, "Test:", X_test.shape)

# Scalers are fitted on the training split only and then applied to val/test.
scaler_X = StandardScaler()
scaler_Y = StandardScaler()

X_train_scaled = scaler_X.fit_transform(X_train)
X_val_scaled, X_test_scaled = (scaler_X.transform(part) for part in (X_val, X_test))

Y_train_scaled = scaler_Y.fit_transform(Y_train)
Y_val_scaled, Y_test_scaled = (scaler_Y.transform(part) for part in (Y_val, Y_test))

# Store the scaler parameters together with the column names they refer to.
scaler_info = dict(
    feature_cols_raw=feature_cols,
    feature_cols_display=display_feature_names,
    target_cols_raw=target_cols,
    target_cols_display=target_display_names,
    X_mean=scaler_X.mean_.tolist(),
    X_scale=scaler_X.scale_.tolist(),
    Y_mean=scaler_Y.mean_.tolist(),
    Y_scale=scaler_Y.scale_.tolist(),
)
with open(OUT_DIR / "scalers_cbc_v2.json", "w") as f:
    json.dump(scaler_info, f, indent=2)

# Aliases used by the SHAP sections below.
clinical_feature_names = feature_cols
clinical_feature_names_display = display_feature_names


# %% [markdown]
# ## 3) Train multi-output linear model (Ridge)

# %%
# Ridge multi-output: coef_ shape (n_targets, n_features)
# (alpha=0 -> equivalente a OLS, pero Ridge suele ser más estable)
multi_model = Ridge(alpha=1.0, random_state=42)

def _train_multi_core():
    """Fit the multi-output Ridge on the scaled training split (profiled unit)."""
    multi_model.fit(X_train_scaled, Y_train_scaled)

(_, train_time_multi, mem_peak_mb, mem_start_mb, mem_end_mb) = profile_stage_total_MB(
    _train_multi_core, interval=0.05
)

print(f"\n[Multi-output LINEAR Ridge] Training time: {train_time_multi:.3f} s")
print(f"[Multi-output LINEAR Ridge] Training memory PEAK: {mem_peak_mb:.2f} MB (start={mem_start_mb:.2f}, end={mem_end_mb:.2f})")

add_runtime_record(
    "train_multi",
    "multi-output-linear-ridge",
    train_time_multi,
    mem_peak_mb,
    mem_start_mb,
    mem_end_mb,
)


# %% [markdown]
# ## 4) Evaluation of multi-output model

# %%
def eval_split_multi(split_name, X_scaled, Y_real):
    """
    Evaluate ``multi_model`` on one split, in the original target units.

    Predictions are made on ``X_scaled`` and mapped back with
    ``scaler_Y.inverse_transform``. For each target the function computes
    R2, RMSE, MAE and an F1 score obtained by binarising truth and
    prediction at the median of that target in ``Y_train``.

    Returns:
        ``(metrics_df, preds_real)``: one row per target, and the unscaled
        prediction matrix.
    """
    preds_real = scaler_Y.inverse_transform(multi_model.predict(X_scaled))

    rows = []
    for k, out_name in enumerate(target_cols):
        truth = Y_real[:, k]
        guess = preds_real[:, k]
        # Threshold comes from the TRAINING targets, whatever split is evaluated.
        cut = np.median(Y_train[:, k])
        rows.append({
            "split": split_name,
            "target_raw": out_name,
            "target_display": target_display_map.get(out_name, out_name),
            "R2": r2_score(truth, guess),
            "RMSE": rmse_compat(truth, guess),
            "MAE": mean_absolute_error(truth, guess),
            "F1_median": f1_score((truth > cut).astype(int), (guess > cut).astype(int)),
        })

    return pd.DataFrame(rows), preds_real

# Evaluate the multi-output model on each split; Y_* are in original units.
m_tr, preds_tr   = eval_split_multi("train", X_train_scaled, Y_train)
m_val, preds_val = eval_split_multi("val",   X_val_scaled,   Y_val)
m_te, preds_te   = eval_split_multi("test",  X_test_scaled,  Y_test)

# One table with all splits stacked.
metrics_multi = pd.concat([m_tr, m_val, m_te], ignore_index=True)
metrics_multi.to_csv(TAB_DIR / "metrics_multi_linear_ridge_cbc_v2.csv", index=False)
print(metrics_multi)

# Correlation between the per-target residuals on the test split.
residuals_test = Y_test - preds_te
res_corr = pd.DataFrame(residuals_test, columns=target_cols).corr()
print("\nResiduals correlation (test):")
print(res_corr)
res_corr.to_csv(TAB_DIR / "residuals_corr_multi_linear_ridge_cbc_v2_test.csv", index_label="target_raw")

# Heatmap of the residual correlation matrix; colour scale fixed to [-1, 1].
plt.figure()
plt.imshow(res_corr.values, cmap="coolwarm", vmin=-1, vmax=1)
plt.colorbar()
plt.xticks(range(len(target_cols)), target_display_names)
plt.yticks(range(len(target_cols)), target_display_names)
plt.title("Residuals correlation — MULTI-OUTPUT LINEAR (test, v2)")
plt.tight_layout()
save_current_fig(IMG_DIR / "residuals_corr_multi_linear_ridge_cbc_v2_test")
plt.show()
plt.close()


# %% [markdown]
# ## 5) SHAP — background & explained sets

# %%
n_train_rows = X_train_scaled.shape[0]
n_test_rows = X_test_scaled.shape[0]

# At most 1000 background rows (from train) and 2000 explained rows (from test).
bg_size = min(1000, n_train_rows)
exp_size = min(2000, n_test_rows)

# Fixed seed; the background indices are drawn first, then the explained ones.
rng = np.random.default_rng(0)
bg_idx = rng.choice(n_train_rows, size=bg_size, replace=False)
exp_idx = rng.choice(n_test_rows, size=exp_size, replace=False)

X_bg_scaled = X_train_scaled[bg_idx]
X_exp_scaled = X_test_scaled[exp_idx]

# Scaled copy as a DataFrame, plus the same rows mapped back to original units.
X_exp_np = X_exp_scaled.copy()
X_exp_view = pd.DataFrame(X_exp_np, columns=clinical_feature_names)
X_exp_unscaled = pd.DataFrame(scaler_X.inverse_transform(X_exp_np), columns=clinical_feature_names)

print("Background shape:", X_bg_scaled.shape)
print("Explained shape:", X_exp_scaled.shape)
print("Raw feature names order:", clinical_feature_names)
print("Display feature names:", clinical_feature_names_display)


# %% [markdown]
# ## 6) Compute SHAP values for multi-output linear model

# %%
n_outputs  = len(target_cols)
n_features = len(clinical_feature_names)

def _compute_shap_multi_core():
    """Build a LinearExplainer for ``multi_model`` and explain ``X_exp_scaled``."""
    # Passing a masker avoids the feature_perturbation warning (original author's note).
    masker = shap.maskers.Independent(X_bg_scaled)
    explainer_multi = shap.LinearExplainer(multi_model, masker=masker)
    shap_values_raw = explainer_multi.shap_values(X_exp_scaled)
    return shap_values_raw

shap_values_raw, shap_time_multi, mem_peak_mb, mem_start_mb, mem_end_mb = profile_stage_total_MB(
    _compute_shap_multi_core, interval=0.05
)

print(f"\n[SHAP multi-output LINEAR] Time: {shap_time_multi:.3f} s")
print(f"[SHAP multi-output LINEAR] Memory PEAK: {mem_peak_mb:.2f} MB (start={mem_start_mb:.2f}, end={mem_end_mb:.2f})")

add_runtime_record(
    stage="shap_multi",
    detail="multi-output-linear",
    time_sec=shap_time_multi,
    mem_peak_mb=mem_peak_mb,
    mem_start_mb=mem_start_mb,
    mem_end_mb=mem_end_mb,
)

# Normalise to a list with one array per output:
# shap_values_per_output[k] has shape (n_samples, n_features).
shap_values_per_output = []

if isinstance(shap_values_raw, (list, tuple)):
    # Layout 1: a list/tuple with one entry per output.
    if len(shap_values_raw) != n_outputs:
        raise ValueError(f"Expected {n_outputs} outputs in SHAP list, got {len(shap_values_raw)}")
    for k in range(n_outputs):
        arr = np.array(shap_values_raw[k])
        arr = np.squeeze(arr)
        if arr.ndim == 1:
            # A single explained row was squeezed to 1-D; restore 2-D.
            arr = arr.reshape(-1, n_features)
        if arr.shape[1] != n_features:
            raise ValueError(f"Unexpected SHAP shape {arr.shape} for output {target_cols[k]}")
        shap_values_per_output.append(arr)
else:
    # Layout 2: a single array of shape (n_samples, n_features, n_outputs).
    arr = np.array(shap_values_raw)
    if arr.ndim == 3 and arr.shape[2] == n_outputs:
        for k in range(n_outputs):
            shap_values_per_output.append(arr[:, :, k])
    else:
        raise ValueError(f"Cannot interpret SHAP output shape: {arr.shape}")

print("Multi-output SHAP shapes per output:")
for k, out_name in enumerate(target_cols):
    print(out_name, shap_values_per_output[k].shape)

# File names go through _safe_stem, like the single-output arrays and the
# figures, so a target name with path-hostile characters cannot break the save.
for i, out_name in enumerate(target_cols):
    np.save(OUT_DIR / f"shap_values_multi_linear_{_safe_stem(out_name)}_cbc_v2.npy", shap_values_per_output[i])

with open(OUT_DIR / "shap_multi_meta_linear_cbc_v2.json", "w") as f:
    json.dump({
        "mode": "multi-output-linear",
        "method": "LinearExplainer",
        "background_size": int(X_bg_scaled.shape[0]),
        "explain_size": int(X_exp_scaled.shape[0]),
        "targets_raw": target_cols,
        "targets_display": target_display_names,
        "n_features_used": int(X_exp_scaled.shape[1]),
        "feature_names_raw": clinical_feature_names,
        "feature_names_display": clinical_feature_names_display,
        "model": "Ridge(alpha=1.0)",
    }, f, indent=2)

print("Saved multi-output LINEAR SHAP arrays (.npy) and metadata (.json) for v2.")


# %% [markdown]
# ## 7) SHAP multi-output — bar, beeswarm, heatmap, dependence (display names)

# %%
# mean |SHAP| por salida
# Mean |SHAP| per feature for every target (rows stay in feature order).
mean_abs = {}
for i, out_name in enumerate(target_cols):
    vals = shap_values_per_output[i]  # (n_samples, n_features)
    s = pd.Series(np.mean(np.abs(vals), axis=0), index=clinical_feature_names_display, name=out_name)
    mean_abs[out_name] = s

mean_abs_df = pd.DataFrame(mean_abs)
mean_abs_df.to_csv(TAB_DIR / "mean_abs_shap_multi_linear_cbc_v2.csv", index_label="feature_display")

# Number of features shown in each plot. Titles and the final message are
# derived from these values so the labels cannot drift from what is drawn
# (previously the bar title said "Top-5" while 10 bars were plotted).
bar_top_k = 10
beeswarm_top_k = 6

# Bar + beeswarm per output
for i, out_name in enumerate(target_cols):
    out_display = target_display_map.get(out_name, out_name)

    top_bar = mean_abs[out_name].sort_values(ascending=False).head(bar_top_k)

    plt.figure(figsize=(8, 5))
    ax = plt.gca()
    ax.set_facecolor("white")
    # Reverse so the most important feature is drawn at the top of the barh plot.
    top_bar[::-1].plot(kind="barh", color="#1f77b4")
    plt.title(f"Top-{len(top_bar)} mean |SHAP| — Multi-output LINEAR — {out_display}")
    plt.xlabel("Mean |SHAP value|")
    plt.tight_layout()
    save_current_fig(IMG_DIR / f"shap_bar_multi_linear_{_safe_stem(out_name)}_cbc_v2")
    plt.show()
    plt.close()

    plt.figure(figsize=(8, 6))
    shap.summary_plot(
        shap_values_per_output[i],
        features=X_exp_unscaled,
        feature_names=clinical_feature_names_display,
        max_display=beeswarm_top_k,
        cmap="viridis",
        show=False,
    )
    plt.title(f"Linear Multi-output model — {out_display}")
    plt.tight_layout()
    save_current_fig(IMG_DIR / f"shap_beeswarm_multi_linear_{_safe_stem(out_name)}_cbc_v2")
    plt.show()
    plt.close()

print(f"Saved Top-{bar_top_k} bar & Top-{beeswarm_top_k} beeswarm plots (multi-output LINEAR) with display names.")

# Heatmap of the Top-10 features, ranked by mean |SHAP| summed over outputs.
# `mean_abs[...]` is in feature order here, so `.values` lines up with the index.
combined = pd.DataFrame(
    {out_name: mean_abs[out_name].values for out_name in target_cols},
    index=clinical_feature_names_display,
)
overall = combined.abs().sum(axis=1).sort_values(ascending=False)
topK_heat = 10
top_features = overall.head(topK_heat).index.tolist()

heat_data = combined.loc[top_features].T  # (n_outputs, topK)

plt.figure(figsize=(10, 6))
# No vmin/vmax: the colour scale adapts to the data range.
plt.imshow(heat_data.values, aspect="auto", cmap="coolwarm")
plt.colorbar(label="Mean |SHAP value|")
plt.xticks(np.arange(len(top_features)), top_features, rotation=45, ha="right")
plt.yticks(np.arange(len(target_cols)), target_display_names)
plt.title("Top-10 features — Mean |SHAP| (multi-output LINEAR, CBC v2)")
plt.tight_layout()
save_current_fig(IMG_DIR / "shap_heatmap_multi_linear_top10_cbc_v2")
plt.show()
plt.close()

print("Saved heatmap for Top-10 features (multi-output LINEAR).")

# Dependence plot for the single most important feature of each output.
# Map display name -> column position so the plot can be addressed by index.
name_to_idx = {disp: j for j, disp in enumerate(clinical_feature_names_display)}
top1_idx = {}
for out_name in target_cols:
    s = mean_abs[out_name].sort_values(ascending=False)
    feat_display = s.index[0]  # highest mean |SHAP| for this output
    feat_idx = name_to_idx[feat_display]
    top1_idx[out_name] = (feat_idx, feat_display)

print("Top-1 features (multi-output LINEAR):", {k: v[1] for k, v in top1_idx.items()})

for out_name in target_cols:
    out_display = target_display_map.get(out_name, out_name)
    feat_idx, feat_display = top1_idx[out_name]

    plt.figure()
    # Feature values are passed in original (unscaled) units.
    shap.dependence_plot(
        ind=feat_idx,
        shap_values=shap_values_per_output[target_cols.index(out_name)],
        features=X_exp_unscaled,
        feature_names=clinical_feature_names_display,
        show=False,
    )
    plt.title(f"SHAP dependence — Multi-output LINEAR — {out_display} vs {feat_display}")
    plt.tight_layout()
    save_current_fig(IMG_DIR / f"shap_dependence_multi_linear_{_safe_stem(out_name)}_feat{feat_idx}_cbc_v2")
    plt.show()
    plt.close()

print("Saved dependence plots (multi-output LINEAR) with display names.")


# %% [markdown]
# ## 8) Single-output linear models (Ridge) — training

# %%
# Fitted single-output models, keyed by raw target name.
single_models = {}

def train_single_output_model(k, out_name):
    """
    Fit one Ridge model on target column ``k`` of ``Y_train_scaled``.

    The fit is profiled with ``profile_stage_total_MB``; time and memory are
    printed and appended to the runtime records under stage
    ``"train_single"``. The fitted model is stored in
    ``single_models[out_name]``.
    """
    print(f"\n=== Training single-output LINEAR Ridge model for {out_name} ===")

    model_k = Ridge(alpha=1.0, random_state=42)

    def _train_single_core():
        model_k.fit(X_train_scaled, Y_train_scaled[:, k])

    profiled = profile_stage_total_MB(_train_single_core, interval=0.05)
    _, t_sec, mem_peak, mem_start, mem_end = profiled

    print(f"[Single LINEAR {out_name}] Training time: {t_sec:.3f} s")
    print(f"[Single LINEAR {out_name}] Training memory PEAK: {mem_peak:.2f} MB (start={mem_start:.2f}, end={mem_end:.2f})")

    add_runtime_record(
        "train_single",
        f"{out_name}-linear-ridge",
        t_sec,
        mem_peak,
        mem_start,
        mem_end,
    )

    single_models[out_name] = model_k

for k, out_name in enumerate(target_cols):
    train_single_output_model(k, out_name)

print("Done: trained 3 single-output linear models.")


# %% [markdown]
# ## 9) SHAP for single-output linear models

# %%
# SHAP arrays of the single-output models, keyed by raw target name.
single_shap_values = {}

for k, out_name in enumerate(target_cols):
    print(f"\n=== Computing SHAP for single-output LINEAR: {out_name} ===")
    model_k = single_models[out_name]

    def _compute_shap_single_core():
        """Explain ``X_exp_scaled`` with the current ``model_k``; returns a 2-D array."""
        explainer_k = shap.LinearExplainer(model_k, masker=shap.maskers.Independent(X_bg_scaled))
        arr = np.squeeze(np.array(explainer_k.shap_values(X_exp_scaled)))
        # A single explained row squeezes to 1-D; restore (n_samples, n_features).
        return arr.reshape(-1, n_features) if arr.ndim == 1 else arr

    # The closure is called within this iteration, so it sees this iteration's model_k.
    profiled = profile_stage_total_MB(_compute_shap_single_core, interval=0.05)
    shap_arr, t_sec, mem_peak, mem_start, mem_end = profiled

    print(f"[SHAP single LINEAR {out_name}] Time: {t_sec:.3f} s")
    print(f"[SHAP single LINEAR {out_name}] Memory PEAK: {mem_peak:.2f} MB (start={mem_start:.2f}, end={mem_end:.2f})")

    add_runtime_record(
        "shap_single",
        f"{out_name}-linear",
        t_sec,
        mem_peak,
        mem_start,
        mem_end,
    )

    single_shap_values[out_name] = shap_arr
    np.save(OUT_DIR / f"shap_values_single_linear_{_safe_stem(out_name)}_cbc_v2.npy", shap_arr)

single_shap_meta = {
    "mode": "single-output-linear-per-target",
    "method": "LinearExplainer",
    "background_size": int(X_bg_scaled.shape[0]),
    "explain_size": int(X_exp_scaled.shape[0]),
    "targets_raw": target_cols,
    "targets_display": target_display_names,
    "n_features_used": int(X_exp_scaled.shape[1]),
    "feature_names_raw": clinical_feature_names,
    "feature_names_display": clinical_feature_names_display,
    "model": "Ridge(alpha=1.0)",
}
with open(OUT_DIR / "shap_single_meta_linear_cbc_v2.json", "w") as f:
    json.dump(single_shap_meta, f, indent=2)

print("Saved single-output LINEAR SHAP arrays and metadata.")


# %% [markdown]
# ## 10) SHAP single-output — mean |SHAP|, bar & beeswarm

# %%
# Per-target mean |SHAP|:
#   mean_abs_single[out]      -> Series sorted by importance (descending)
#   mean_abs_single_raw[out]  -> plain array in the original feature order
mean_abs_single = {}
mean_abs_single_raw = {}
for out_name in target_cols:
    v = np.abs(single_shap_values[out_name]).mean(axis=0)
    mean_abs_single_raw[out_name] = v
    s = pd.Series(v, index=clinical_feature_names_display).sort_values(ascending=False)
    mean_abs_single[out_name] = s
    s.to_csv(
        TAB_DIR / f"mean_abs_shap_single_linear_{_safe_stem(out_name)}_cbc_v2.csv",
        header=["mean_abs_shap"],
        index_label="feature_display",
    )

# The combined table must be built from the UNSORTED values. Each sorted
# Series has its own row order, so taking `.values` from them and attaching
# the feature-order index would pair importances with the wrong feature names.
combined_single = pd.DataFrame(
    {out_name: mean_abs_single_raw[out_name] for out_name in target_cols},
    index=clinical_feature_names_display,
)
combined_single.to_csv(
    TAB_DIR / "mean_abs_shap_single_linear_all_cbc_v2.csv",
    index_label="feature_display",
)

print("Saved mean |SHAP| tables for single-output LINEAR models.")

# Number of features shown in each plot. Titles and the final message are
# derived from these values so the labels match what is drawn (previously
# the bar title said "Top-5" while 10 bars were plotted).
single_bar_top_k = 10
single_beeswarm_top_k = 6

for out_name in target_cols:
    out_display = target_display_map.get(out_name, out_name)
    top_bar = mean_abs_single[out_name].sort_values(ascending=False).head(single_bar_top_k)

    plt.figure(figsize=(8, 5))
    ax = plt.gca()
    ax.set_facecolor("white")
    # Reverse so the most important feature is drawn at the top of the barh plot.
    top_bar[::-1].plot(kind="barh", color="#1f77b4")
    plt.title(f"Top-{len(top_bar)} mean |SHAP| — Single-output LINEAR — {out_display}")
    plt.xlabel("Mean |SHAP value|")
    plt.tight_layout()
    save_current_fig(IMG_DIR / f"shap_bar_single_linear_{_safe_stem(out_name)}_cbc_v2")
    plt.show()
    plt.close()

    plt.figure(figsize=(8, 6))
    shap.summary_plot(
        single_shap_values[out_name],
        X_exp_view,
        feature_names=clinical_feature_names_display,
        max_display=single_beeswarm_top_k,
        show=False,
    )
    plt.title(f"SHAP beeswarm — Single-output LINEAR — {out_display}")
    plt.tight_layout()
    # File name kept unchanged ("top5") so existing references to it still resolve.
    save_current_fig(IMG_DIR / f"shap_beeswarm_top5_single_linear_{_safe_stem(out_name)}_cbc_v2")
    plt.show()
    plt.close()

print(f"Saved Top-{single_bar_top_k} bar & Top-{single_beeswarm_top_k} beeswarm plots for single-output LINEAR models.")


# %% [markdown]
# ## 11) SHAP single-output — heatmap, dependence, cross-output SHAP corr

# %%
# Rank features by mean |SHAP| summed over the three targets; keep the top 10.
overall_single = combined_single.abs().sum(axis=1).sort_values(ascending=False)
topK_single = 10
top_features_single = overall_single.head(topK_single).index.tolist()

# Rows = top features, columns = targets.
heat_single = combined_single.loc[top_features_single, :].values

plt.figure(figsize=(8, 6))
plt.imshow(heat_single, aspect="auto")
plt.colorbar()
plt.xticks(range(len(target_display_names)), target_display_names, rotation=0)
plt.yticks(range(len(top_features_single)), top_features_single)
plt.title("Mean |SHAP| (single-output LINEAR, v2) — Top features vs targets")
plt.tight_layout()
save_current_fig(IMG_DIR / "shap_heatmap_single_linear_top_features_cbc_v2")
plt.show()
plt.close()

# Most important feature per target, stored as (column position, display name).
top1_single_idx = {}
name_to_idx = {disp: j for j, disp in enumerate(clinical_feature_names_display)}
for out_name in target_cols:
    # mean_abs_single[...] was stored sorted descending, so index[0] is the top feature.
    s = mean_abs_single[out_name]
    feat_display = s.index[0]
    feat_idx = name_to_idx[feat_display]
    top1_single_idx[out_name] = (feat_idx, feat_display)

print("Top-1 features (single-output LINEAR):", {k: v[1] for k, v in top1_single_idx.items()})

for out_name in target_cols:
    out_display = target_display_map.get(out_name, out_name)
    feat_idx, feat_display = top1_single_idx[out_name]

    plt.figure()
    # Feature values are passed in original (unscaled) units.
    shap.dependence_plot(
        ind=feat_idx,
        shap_values=single_shap_values[out_name],
        features=X_exp_unscaled,
        feature_names=clinical_feature_names_display,
        show=False
    )
    plt.title(f"SHAP dependence — Single-output LINEAR — {out_display} vs {feat_display}")
    save_current_fig(IMG_DIR / f"shap_dependence_single_linear_{_safe_stem(out_name)}_feat{feat_idx}_cbc_v2")
    plt.show()
    plt.close()

from itertools import combinations

# Signed mean SHAP per feature for each target (not the absolute value).
signed_mean_shap = {
    name: pd.Series(single_shap_values[name].mean(axis=0), index=clinical_feature_names_display)
    for name in target_cols
}

# Pearson correlation of those per-feature means for every pair of targets.
corr_rows_single = [
    {
        "output_A_raw": out_a,
        "output_A_display": target_display_map.get(out_a, out_a),
        "output_B_raw": out_b,
        "output_B_display": target_display_map.get(out_b, out_b),
        "featurewise_SHAP_corr": float(signed_mean_shap[out_a].corr(signed_mean_shap[out_b])),
    }
    for out_a, out_b in combinations(target_cols, 2)
]

corr_df_single = pd.DataFrame(corr_rows_single)
corr_df_single.to_csv(TAB_DIR / "cross_output_shap_single_linear_featurewise_corr_cbc_v2.csv", index=False)
print(corr_df_single)
print("Saved cross-output SHAP correlation CSV for single-output LINEAR models.")


# %% [markdown]
# ## 12) Output correlations: TRUE vs multi-output vs single-output (display names)

# %%
print("=== Computing correlations among HGB, HCT, RBC (v2) ===")

# Correlation among the true test targets.
true_corr = pd.DataFrame(Y_test, columns=target_cols).corr()

# Multi-output predictions on the test split, mapped back to original units.
preds_multi_std = multi_model.predict(X_test_scaled)
preds_multi = scaler_Y.inverse_transform(preds_multi_std)
preds_multi_df = pd.DataFrame(preds_multi, columns=target_cols)
multi_corr = preds_multi_df.corr()

# Single-output predictions: un-scale each column with that target's own
# scale and mean from scaler_Y.
preds_single_dict = {
    out_name: (
        single_models[out_name].predict(X_test_scaled).reshape(-1) * scaler_Y.scale_[k]
        + scaler_Y.mean_[k]
    )
    for k, out_name in enumerate(target_cols)
}

preds_single_df = pd.DataFrame(preds_single_dict)
single_corr = preds_single_df.corr()

for corr_frame, file_name in (
    (true_corr, "corr_true_outputs_cbc_v2_test.csv"),
    (multi_corr, "corr_multi_output_linear_preds_cbc_v2_test.csv"),
    (single_corr, "corr_single_output_linear_preds_cbc_v2_test.csv"),
):
    corr_frame.to_csv(TAB_DIR / file_name, index_label="target_raw")

print("True correlation:\n", true_corr)
print("\nMulti-output LINEAR preds correlation:\n", multi_corr)
print("\nSingle-output LINEAR preds correlation:\n", single_corr)

def plot_corr_heatmap(mat, title, filename):
    """
    Draw ``mat`` as a heatmap on a fixed [-1, 1] colour scale and save it.

    Both axes are labelled with ``target_display_names``. The figure is
    saved under ``IMG_DIR / filename`` via ``save_current_fig``, shown, and
    then closed.
    """
    tick_positions = range(len(target_display_names))

    plt.figure(figsize=(5, 4))
    plt.imshow(mat, cmap="coolwarm", vmin=-1, vmax=1)
    plt.colorbar()
    plt.xticks(tick_positions, target_display_names)
    plt.yticks(tick_positions, target_display_names)
    plt.title(title)
    plt.tight_layout()

    save_current_fig(IMG_DIR / filename)
    plt.show()
    plt.close()

# (matrix, title, file stem) for the three correlation heatmaps, drawn in order.
for corr_matrix, heat_title, heat_stem in (
    (true_corr.values,
     "Correlation — TRUE outputs (test, v2)",
     "corr_heatmap_true_outputs_cbc_v2_test"),
    (multi_corr.values,
     "Correlation — MULTI-OUTPUT LINEAR predictions (test, v2)",
     "corr_heatmap_multi_output_linear_preds_cbc_v2_test"),
    (single_corr.values,
     "Correlation — SINGLE-OUTPUT LINEAR predictions (test, v2)",
     "corr_heatmap_single_output_linear_preds_cbc_v2_test"),
):
    plot_corr_heatmap(corr_matrix, heat_title, heat_stem)

print("\n=== Done! Correlation matrices and heatmaps saved successfully (LINEAR v2). ===")


# %% [markdown]
# ## 13) Save runtime & memory summary (MB)

# %%
# Dump every profiled stage (training and SHAP, multi and single) to one CSV.
csv_path = TAB_DIR / "runtime_memory_cbc_v2_MB_linear.csv"
runtime_df = pd.DataFrame(runtime_records)
runtime_df.to_csv(csv_path, index=False)

print("\n=== Runtime & Memory summary saved ===")
print("Path:", csv_path)
print(runtime_df)


# %% [markdown]
# ## 14) Evaluation metrics for Multi-output and Single-output Models (Train / Val / Test)

# %%
def eval_multi_split_all(split_name, X_scaled, Y_real):
    """
    Return one metrics row (dict) per target for ``multi_model`` on one split.

    Predictions are mapped back to original units with ``scaler_Y``. Each
    row holds R2, RMSE, MAE and an F1 score computed after binarising truth
    and prediction at the median of this split's own true values. (This
    differs from ``eval_split_multi``, which thresholds at the training
    median.)
    """
    preds_real = scaler_Y.inverse_transform(multi_model.predict(X_scaled))

    rows = []
    for j, out_name in enumerate(target_cols):
        truth, guess = Y_real[:, j], preds_real[:, j]
        cut = np.median(truth)
        rows.append({
            "model_type": "multi-output-linear",
            "split": split_name,
            "output": target_display_map.get(out_name, out_name),
            "R2": r2_score(truth, guess),
            "RMSE": np.sqrt(mean_squared_error(truth, guess)),
            "MAE": mean_absolute_error(truth, guess),
            "F1_median": f1_score((truth > cut).astype(int), (guess > cut).astype(int)),
        })
    return rows

# Collect the rows of all three splits, in train/val/test order.
metrics_multi_all = [
    row
    for split_label, X_part, Y_part in (
        ("train", X_train_scaled, Y_train),
        ("val", X_val_scaled, Y_val),
        ("test", X_test_scaled, Y_test),
    )
    for row in eval_multi_split_all(split_label, X_part, Y_part)
]

df_metrics_multi_all = pd.DataFrame(metrics_multi_all)
df_metrics_multi_all.to_csv(TAB_DIR / "metrics_multi_output_linear_all_splits_cbc_v2.csv", index=False)

print("Multi-output LINEAR metrics (train/val/test):")
print(df_metrics_multi_all)


def eval_single_split_all(split_name, X_scaled, Y_real):
    """
    Return one metrics row (dict) per target using the single-output models.

    Each model in ``single_models`` predicts in scaled units; the prediction
    is mapped back with that target's ``scaler_Y.scale_`` / ``mean_`` entry.
    Metrics are R2, RMSE, MAE and an F1 score after binarising truth and
    prediction at the median of this split's own true values.
    """
    rows = []
    for k, out_name in enumerate(target_cols):
        scaled_pred = single_models[out_name].predict(X_scaled).reshape(-1)
        guess = scaled_pred * scaler_Y.scale_[k] + scaler_Y.mean_[k]
        truth = Y_real[:, k]
        cut = np.median(truth)
        rows.append({
            "model_type": "single-output-linear",
            "split": split_name,
            "output": target_display_map.get(out_name, out_name),
            "R2": r2_score(truth, guess),
            "RMSE": np.sqrt(mean_squared_error(truth, guess)),
            "MAE": mean_absolute_error(truth, guess),
            "F1_median": f1_score((truth > cut).astype(int), (guess > cut).astype(int)),
        })
    return rows

# Collect the rows of all three splits, in train/val/test order.
metrics_single_all = [
    row
    for split_label, X_part, Y_part in (
        ("train", X_train_scaled, Y_train),
        ("val", X_val_scaled, Y_val),
        ("test", X_test_scaled, Y_test),
    )
    for row in eval_single_split_all(split_label, X_part, Y_part)
]

df_metrics_single_all = pd.DataFrame(metrics_single_all)
df_metrics_single_all.to_csv(TAB_DIR / "metrics_single_output_linear_all_splits_cbc_v2.csv", index=False)

print("Single-output LINEAR metrics (train/val/test):")
print(df_metrics_single_all)


# %% [markdown]
# ## 15) Redesigned BAR comparison (Multi vs Single) per output (TopK)

# %%
# -----------------------------
# User-configurable parameters
# -----------------------------
# How many features to show, and the bar colour for each model family.
TOPK = 10
COLOR_MULTI = "#1f77b4"
COLOR_SINGLE = "#ff7f0e"

# Font sizes (points) for title, axis labels, tick labels and legend.
FS_TITLE, FS_LABEL, FS_TICK, FS_LEG = 18, 16, 16, 13

# Applied globally: this affects every figure created after this point.
_font_rc = {
    "axes.titlesize": FS_TITLE,
    "axes.labelsize": FS_LABEL,
    "xtick.labelsize": FS_TICK,
    "ytick.labelsize": FS_TICK,
    "legend.fontsize": FS_LEG,
}
plt.rcParams.update(_font_rc)

def mean_abs_series(shap_array, feature_names_display):
    """
    Average the absolute SHAP values over axis 0 and label the result.

    Returns a pandas Series of mean |SHAP| values whose index is
    ``feature_names_display``.
    """
    magnitudes = np.abs(shap_array)
    per_feature = np.mean(magnitudes, axis=0)
    return pd.Series(data=per_feature, index=feature_names_display)

# One horizontal grouped-bar figure per output: the TOPK features ranked by the
# summed mean |SHAP| of both models, multi-output next to single-output.
for out_idx, out_name in enumerate(target_cols):
    out_disp = target_display_map.get(out_name, out_name)

    s_multi = mean_abs_series(shap_values_per_output[out_idx], clinical_feature_names_display)
    s_single = mean_abs_series(single_shap_values[out_name], clinical_feature_names_display)

    # Rank on the combined importance so both models are shown on the same features.
    s_combined = (s_multi + s_single).sort_values(ascending=False)
    top_feats = s_combined.index[:TOPK].tolist()

    df_bar = pd.DataFrame(
        {
            "Multi-output": s_multi.loc[top_feats].values,
            "Single-output": s_single.loc[top_feats].values,
        },
        index=top_feats,
    )

    ax = plt.figure(figsize=(10, 6)).gca()
    ax.set_facecolor("white")

    y = np.arange(len(top_feats))
    h = 0.38

    # Two bars per feature, shifted half a bar height to either side of the tick.
    for _shift, _column, _colour in (
        (-h / 2, "Multi-output", COLOR_MULTI),
        (+h / 2, "Single-output", COLOR_SINGLE),
    ):
        ax.barh(y + _shift, df_bar[_column], height=h, color=_colour, label=_column)

    ax.set_yticks(y)
    ax.set_yticklabels(top_feats, fontsize=FS_TICK)
    ax.invert_yaxis()  # position 0 (highest-ranked feature) ends up at the top

    ax.set_xlabel("Mean |SHAP value|", fontsize=FS_LABEL)
    ax.legend(loc="lower right", fontsize=FS_LEG)

    plt.tight_layout()
    save_current_fig(
        IMG_DIR / f"redesign_bar_compare_multi_vs_single_linear_{out_idx}_top{TOPK}_cbc_v2"
    )
    plt.show()
    plt.close()

print(f"Saved redesigned BAR comparison plots (Top-{TOPK}) for all outputs (LINEAR).")


# %% [markdown]
# ## 16) Quantitative similarity of raw SHAP values (multi vs single)

# %%
from scipy.stats import spearmanr

def cosine_similarity(vec_a, vec_b, eps=1e-12):
    """
    Cosine similarity of two 1-D vectors, returned as a Python float.

    ``eps`` is added to the product of the two norms so that an all-zero
    vector yields 0.0 instead of a division by zero.
    """
    norm_product = np.linalg.norm(vec_a) * np.linalg.norm(vec_b)
    return float(np.dot(vec_a, vec_b) / (norm_product + eps))

# For every output, flatten the SHAP matrices of the multi-output and the
# single-output model and compare them with cosine similarity and Spearman rho.
similarity_rows = []

for k, out_name in enumerate(target_cols):
    out_display = target_display_map.get(out_name, out_name)
    shap_multi, shap_single = shap_values_per_output[k], single_shap_values[out_name]

    # The element-wise comparison below requires equally shaped matrices.
    if shap_multi.shape != shap_single.shape:
        raise ValueError(f"Shape mismatch for output {out_name}: multi {shap_multi.shape}, single {shap_single.shape}")

    vec_multi = shap_multi.reshape(-1)
    vec_single = shap_single.reshape(-1)

    cos_sim = cosine_similarity(vec_multi, vec_single)
    spearman_corr, spearman_p = spearmanr(vec_multi, vec_single)

    similarity_rows.append(
        dict(
            output_raw=out_name,
            output_display=out_display,
            n_samples=shap_multi.shape[0],
            n_features=shap_multi.shape[1],
            cosine_similarity=cos_sim,
            spearman_correlation=float(spearman_corr),
            spearman_pvalue=float(spearman_p),
        )
    )

# Per-output table.
similarity_df = pd.DataFrame(similarity_rows)
csv_path = TAB_DIR / "shap_local_similarity_multi_vs_single_linear_cbc_v2.csv"
similarity_df.to_csv(csv_path, index=False)

print("\n=== Local SHAP similarity (LINEAR multi vs single) — per output ===")
print(similarity_df)
print(f"\nSaved SHAP similarity metrics to:\n{csv_path}")

# Mean and sample standard deviation (ddof=1) of both metrics across outputs.
_summary_metrics = ["cosine_similarity", "spearman_correlation"]
summary_df = pd.DataFrame(
    {
        "metric": _summary_metrics,
        "mean": [similarity_df[_m].mean() for _m in _summary_metrics],
        "std": [similarity_df[_m].std(ddof=1) for _m in _summary_metrics],
    }
)
summary_path = TAB_DIR / "shap_local_similarity_summary_multi_vs_single_linear_cbc_v2.csv"
summary_df.to_csv(summary_path, index=False)

print("\n=== Summary across outputs (mean ± std) ===")
print(summary_df)
print(f"\nSaved summary to:\n{summary_path}")


# %% [markdown]
# ## 17) Local SHAP similarity restricted to Top-K features (per output)

# %%
# Number of top-ranked features kept per output. The console messages below are
# built from this value so they always report the K that was actually used.
topK = 10
similarity_topk_rows = []

for k, out_name in enumerate(target_cols):
    out_display = target_display_map.get(out_name, out_name)

    shap_multi  = shap_values_per_output[k]
    shap_single = single_shap_values[out_name]

    # Rank features by the summed mean |SHAP| of both models so the same
    # feature subset is compared on each side.
    mean_multi  = np.mean(np.abs(shap_multi), axis=0)
    mean_single = np.mean(np.abs(shap_single), axis=0)

    combined_importance = mean_multi + mean_single
    # argsort is ascending; reverse it and keep the first topK column indices.
    topk_idx = np.argsort(combined_importance)[::-1][:topK]
    topk_features_display = [clinical_feature_names_display[j] for j in topk_idx]

    # Restrict both SHAP matrices to the selected columns and flatten them.
    shap_multi_topk  = shap_multi[:, topk_idx]
    shap_single_topk = shap_single[:, topk_idx]

    vec_multi  = shap_multi_topk.reshape(-1)
    vec_single = shap_single_topk.reshape(-1)

    cos_sim = cosine_similarity(vec_multi, vec_single)
    spearman_corr, spearman_p = spearmanr(vec_multi, vec_single)

    similarity_topk_rows.append({
        "output_raw": out_name,
        "output_display": out_display,
        "topK": topK,
        "topK_features_display": "; ".join(topk_features_display),
        "n_samples": shap_multi_topk.shape[0],
        "cosine_similarity_topK": cos_sim,
        "spearman_correlation_topK": float(spearman_corr),
        "spearman_pvalue_topK": float(spearman_p),
    })

similarity_topk_df = pd.DataFrame(similarity_topk_rows)
# NOTE: the file name carries a fixed "top5" tag regardless of topK; the K that
# was actually used is stored in the "topK" column of the table.
csv_path = TAB_DIR / "shap_local_similarity_top5_multi_vs_single_linear_cbc_v2.csv"
similarity_topk_df.to_csv(csv_path, index=False)

print(f"\n=== Local SHAP similarity (Top-{topK} features, LINEAR multi vs single) ===")
print(similarity_topk_df)
print(f"\nSaved Top-{topK} SHAP similarity metrics to:\n{csv_path}")

# Mean and sample standard deviation (ddof=1) of both metrics across outputs.
summary_topk_df = pd.DataFrame({
    "metric": ["cosine_similarity_topK", "spearman_correlation_topK"],
    "mean": [
        similarity_topk_df["cosine_similarity_topK"].mean(),
        similarity_topk_df["spearman_correlation_topK"].mean(),
    ],
    "std": [
        similarity_topk_df["cosine_similarity_topK"].std(ddof=1),
        similarity_topk_df["spearman_correlation_topK"].std(ddof=1),
    ],
})
# NOTE: same fixed "top5" tag in this file name as in the per-output table.
summary_path = TAB_DIR / "shap_local_similarity_top5_summary_multi_vs_single_linear_cbc_v2.csv"
summary_topk_df.to_csv(summary_path, index=False)

print(f"\n=== Top-{topK} summary across outputs (mean ± std) ===")
print(summary_topk_df)
print(f"\nSaved Top-{topK} summary to:\n{summary_path}")


# %% [markdown]
# ## 18) Combined figure: Top-K mean |SHAP| (Multi vs Single) for all outputs

# %%
print("\n[FIG18] Starting combined TopK mean|SHAP| figure...")

# Settings for the combined figure, assigned at module level.
TOPK = 10
COLOR_MULTI, COLOR_SINGLE = "#1f77b4", "#ff7f0e"

# Font sizes: title, axis labels, tick labels, legend.
FS_TITLE, FS_LABEL, FS_TICK, FS_LEG = 16, 14, 12, 12

def _mean_abs_series(shap_array, feature_names_display):
    """
    Return the mean |SHAP| along axis 0 as a Series indexed by the display names.
    """
    abs_values = np.abs(shap_array)
    column_means = np.mean(abs_values, axis=0)
    return pd.Series(column_means, index=feature_names_display)

# One panel per output, side by side.
# No layout engine is requested on purpose: the bottom margin is set manually
# below with subplots_adjust(), and recent Matplotlib skips that call (with a
# warning) on figures created with constrained_layout=True, which would leave
# the figure-level legend on top of the rotated tick labels.
fig, axes = plt.subplots(1, len(target_cols), figsize=(20, 6))
if len(target_cols) == 1:
    # plt.subplots returns a bare Axes for a single panel; wrap it for indexing.
    axes = [axes]

bar_width = 0.38

for out_idx, out_name in enumerate(target_cols):
    ax = axes[out_idx]
    out_disp = target_display_map.get(out_name, out_name)

    s_multi  = _mean_abs_series(shap_values_per_output[out_idx], clinical_feature_names_display)
    s_single = _mean_abs_series(single_shap_values[out_name],     clinical_feature_names_display)

    # Rank on the summed importance so both models are shown on the same features.
    s_combined = (s_multi + s_single).sort_values(ascending=False)
    top_feats = s_combined.head(TOPK).index.tolist()

    y_multi  = s_multi.loc[top_feats].values
    y_single = s_single.loc[top_feats].values

    # Two bars per feature, shifted half a bar width to either side of the tick.
    x = np.arange(len(top_feats))
    ax.bar(x - bar_width/2, y_multi,  width=bar_width, color=COLOR_MULTI,  label="Multi-output")
    ax.bar(x + bar_width/2, y_single, width=bar_width, color=COLOR_SINGLE, label="Single-output")

    ax.set_xticks(x)
    ax.set_xticklabels(top_feats, rotation=90, ha="center", fontsize=FS_TICK)

    # Only the leftmost panel carries the y-axis label.
    if out_idx == 0:
        ax.set_ylabel("Mean |SHAP value|", fontsize=FS_LABEL)

    ax.set_title(out_disp, fontsize=FS_TITLE)
    ax.tick_params(axis="y", labelsize=FS_TICK)
    ax.grid(axis="y", alpha=0.25)

# A single figure-level legend, built from the first panel's handles.
handles, labels = axes[0].get_legend_handles_labels()
fig.legend(handles, labels, loc="lower center", ncol=2, fontsize=FS_LEG, frameon=True)

# Reserve the lower 30% of the figure for the rotated tick labels and the legend.
fig.subplots_adjust(bottom=0.30)

out_path = IMG_DIR / f"combined_top{TOPK}_bar_mean_abs_shap_multi_vs_single_all_outputs_linear_cbc_v2"
save_current_fig(out_path)

print(f"[FIG18] Saved figure to: {out_path.with_suffix('.png')} and {out_path.with_suffix('.pdf')}")
plt.show()
plt.close(fig)

print("[FIG18] Done.")
