## Systematic ablation analysis of modelling operations based on out-of-sample prediction error

In [None]:
# -*- coding: utf-8 -*-
"""
plot for CV_RMSE across model permutations (by age)
============================================================================
Input:
  For each age folder:
    FULL_EXHAUST_4EXPO_results_all_models.csv
Output:
  AGE_<age>/ED_FIG_CV_ops_<age>.png/.pdf

Notes:
- Pink: models INCLUDING the term/operation
- Grey: models EXCLUDING the term/operation
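- Right-hand label: mean(without) - mean(with) and the Welch t-test p-value (positive difference => models including the term have lower mean RMSE)
- Black mean-difference bar and red star (Welch t-test p < 0.05): the drawing code is currently commented out in plot_one_age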
"""

import os
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt

from scipy.stats import gaussian_kde, ttest_ind

# =========================
# STYLE
# =========================
# Font type 42 (TrueType) is requested for both PDF and PostScript output.
mpl.rcParams["pdf.fonttype"] = 42
mpl.rcParams["ps.fonttype"]  = 42
# Sans-serif fallback chain; SimHei is listed so CJK glyphs can render.
plt.rcParams["font.sans-serif"] = ["Arial", "SimHei", "DejaVu Sans"]
# Draw minus signs with the ASCII hyphen instead of the Unicode minus glyph.
plt.rcParams["axes.unicode_minus"] = False

# =========================
# PATHS
# =========================
# Root folder holding one sub-folder per age group.
ROOT = r"D:\AAUDE\paper_v2\paper2\data\model_outputs\FULL_EXHAUST_4EXPO_BY_AGE"
# Age label -> folder that is read for the results CSV and written with the figures.
AGE_DIRS = {
    "total": os.path.join(ROOT, "AGE_total"),
    "u5":    os.path.join(ROOT, "AGE_u5"),
    "5_65":  os.path.join(ROOT, "AGE_5_65"),
    "65p":   os.path.join(ROOT, "AGE_65p"),
}

# Per-age results table and the column holding the cross-validated RMSE.
CSV_NAME = "FULL_EXHAUST_4EXPO_results_all_models.csv"
RMSE_COL = "CV_RMSE"

# Number of x-grid points on which each KDE curve is evaluated.
KDE_N = 400


# Minimum number of finite values required in each group before welch_p runs the t-test.
MIN_N = 25

# =========================
# HELPERS
# =========================
def _as_set(x):
    if pd.isna(x):
        return set()
    s = str(x).strip()
    if s == "" or s.upper() == "NO_INTER":
        return set()
    # ‰Ω†ÁöÑ interactions Â≠óÊÆµÊòØ "HAP_X+CLIM1_X" ËøôÁßç
    return set([t.strip() for t in s.split("+") if t.strip()])

def kde_curve(values, xgrid):
    """Evaluate a Gaussian KDE of ``values`` on ``xgrid``.

    Non-finite entries are discarded first. With fewer than five finite
    values an all-zero curve is returned. When the remaining values are
    (numerically) identical, ``gaussian_kde`` is not called; instead a
    narrow pulse of height 1.0 is placed on the grid point closest to the
    value and on its immediate neighbours.
    """
    finite = np.asarray(values, dtype=float)
    finite = finite[np.isfinite(finite)]
    if finite.size < 5:
        return np.zeros_like(xgrid)

    # Guard against a degenerate sample, which would make the KDE fail.
    if np.nanstd(finite) < 1e-12:
        pulse = np.zeros_like(xgrid)
        centre = np.argmin(np.abs(xgrid - float(finite[0])))
        first = max(0, centre - 1)
        last = min(len(pulse), centre + 2)
        pulse[first:last] = 1.0
        return pulse

    density = gaussian_kde(finite)
    return density(xgrid)

def welch_p(a, b):
    """Return the p-value of Welch's unequal-variance t-test for ``a`` vs ``b``.

    Non-finite entries are removed from both samples. NaN is returned when
    either sample has fewer than ``MIN_N`` finite values, or when the test
    itself raises.
    """
    cleaned = []
    for sample in (a, b):
        arr = np.asarray(sample, dtype=float)
        cleaned.append(arr[np.isfinite(arr)])
    left, right = cleaned

    if min(left.size, right.size) < MIN_N:
        return np.nan
    try:
        outcome = ttest_ind(left, right, equal_var=False)
        return float(outcome.pvalue)
    except Exception:
        # Best effort: a failed test is reported as "no p-value".
        return np.nan

def mean_diff(with_vals, without_vals):
    """Return ``mean(without_vals) - mean(with_vals)`` over finite entries.

    NaN is returned when either group has fewer than five finite values.
    """
    included = np.asarray(with_vals, dtype=float)
    excluded = np.asarray(without_vals, dtype=float)
    included = included[np.isfinite(included)]
    excluded = excluded[np.isfinite(excluded)]

    if min(included.size, excluded.size) < 5:
        return np.nan
    gap = np.nanmean(excluded) - np.nanmean(included)
    return float(gap)

# =========================
# TERM DEFINITIONS (operations)
# =========================
def build_terms(df):
    """Return the operations to compare as a list of ``(label, mask_func)``.

    ``mask_func(frame)`` returns a boolean NumPy array with one entry per
    row of ``frame``; True marks a model that includes the operation.

    The masks are computed from the frame handed to ``mask_func``. Earlier
    versions ignored that argument and always used the ``df`` given here,
    which produced a mask of the wrong length for any other frame. ``df``
    is kept in the signature for backward compatibility and is not read.

    A missing column never raises: the mask is simply all False.
    """
    def col_eq(frame, col, val):
        # True where the design column equals the requested level (string compare).
        if col not in frame.columns:
            return np.zeros(len(frame), dtype=bool)
        return (frame[col].astype(str) == str(val)).to_numpy()

    def has_inter(frame, term):
        # True where the "+"-separated interactions cell lists ``term``.
        if "interactions" not in frame.columns:
            return np.zeros(len(frame), dtype=bool)
        hits = frame["interactions"].apply(lambda cell: term in _as_set(cell))
        # dtype=bool keeps the mask usable with ``~`` even for an empty frame.
        return hits.to_numpy(dtype=bool)

    column_ops = [
        ("HAP log1p",          "hap", "hap_log1p"),
        ("PM2.5 log1p",        "pm25", "pm25_log1p"),
        ("Z-score all vars",   "zall", "zall"),
        ("Climate: AH resid | TAVG", "clim_struct", "both_AHresid_on_TAVG"),
        ("Climate: TAVG resid | AH", "clim_struct", "both_TAVGresid_on_AH"),
        ("SDI scheme: Q4 merge top", "sdi_scheme", "Q4_mergeTop"),
        ("Coding: star_expand",       "coding", "star_expand"),
        ("Center interaction within SDI group", "center_inter", "centerWithinQ"),
    ]
    # interactions presence
    interaction_ops = [
        ("Interaction includes HAP",   "HAP_X"),
        ("Interaction includes PM2.5", "PM25_X"),
        ("Interaction includes CLIM1", "CLIM1_X"),
        ("Interaction includes CLIM2", "CLIM2_X"),
    ]

    # Default arguments bind each column/value pair at definition time, so every
    # lambda keeps its own target instead of sharing the last loop value.
    terms = [
        (label, lambda d, _col=col, _val=val: col_eq(d, _col, _val))
        for label, col, val in column_ops
    ]
    terms += [
        (label, lambda d, _term=term: has_inter(d, _term))
        for label, term in interaction_ops
    ]
    return terms

# =========================
# PLOT ONE AGE
# =========================
def plot_one_age(age, csv_fp, out_dir):
    """Plot CV_RMSE densities for one age group and save them as PNG and PDF.

    The figure has one row for the unconditional density of all models plus
    one row per entry of ``build_terms``; each of those rows overlays the
    density of models with the operation and the density of models without it.

    Parameters
    ----------
    age : label used in the title and in the output file names.
    csv_fp : path of the results CSV; it must contain the ``RMSE_COL`` column.
    out_dir : folder that receives ``ED_FIG_CV_ops_<age>.png`` and ``.pdf``.

    Raises
    ------
    ValueError : the CSV has no ``RMSE_COL`` column.
    RuntimeError : no row has a finite RMSE.
    """
    df = pd.read_csv(csv_fp)

    if RMSE_COL not in df.columns:
        raise ValueError(f"Missing {RMSE_COL} in {csv_fp}")

    # Coerce to numeric and keep only rows with a finite RMSE.
    df[RMSE_COL] = pd.to_numeric(df[RMSE_COL], errors="coerce")
    df = df[np.isfinite(df[RMSE_COL])].copy()

    if len(df) == 0:
        raise RuntimeError(f"No finite RMSE rows for age={age}")

    # x range for plots
    # The range is hard-coded to [0.2, 1]; with these literals the data-driven
    # fallback below is never taken.
    x_min = 0.2
    x_max = 1
    if not np.isfinite(x_min) or not np.isfinite(x_max) or x_max <= x_min:
        x_min = float(np.nanmin(df[RMSE_COL]))
        x_max = float(np.nanmax(df[RMSE_COL]))
    # Widen the range by 5% on each side before building the KDE grid.
    pad = 0.05 * (x_max - x_min + 1e-12)
    x_min -= pad; x_max += pad
    xgrid = np.linspace(x_min, x_max, KDE_N)

    terms = build_terms(df)

    # rows: 1 unconditional + n_terms
    n_rows = 1 + len(terms)

    fig_h = 6 * n_rows  # height scales with the number of rows (6 inches each)
    fig_w = 8
    fig, axes = plt.subplots(n_rows, 1, figsize=(fig_w, fig_h), sharex=True)

    # ----- Row 0: unconditional density
    ax0 = axes[0]
    y0 = kde_curve(df[RMSE_COL].to_numpy(), xgrid)
    ax0.fill_between(xgrid, 0, y0, alpha=0.25)
    ax0.plot(xgrid, y0, lw=1.5)
    ax0.set_ylabel("Density")
    ax0.set_title(f"{age}: Unconditional OOS RMSE density (all models)", fontsize=13)
    COLOR_WITH    = "#E64B35"   # soft red / pink
    COLOR_WITHOUT = "#4D4D4D"   # dark grey

    # ----- Other rows: conditional densities
    for i, (label, mask_func) in enumerate(terms, start=1):
        ax = axes[i]
        # Boolean mask: True for models that include this operation.
        m = mask_func(df)

        with_vals = df.loc[m, RMSE_COL].to_numpy()
        wo_vals   = df.loc[~m, RMSE_COL].to_numpy()

        y_with = kde_curve(with_vals, xgrid)
        y_wo   = kde_curve(wo_vals, xgrid)

        # grey = without
        ax.fill_between(xgrid, 0, y_wo, color=COLOR_WITHOUT, alpha=0.30)
        ax.plot(xgrid, y_wo, color=COLOR_WITHOUT, lw=1.0)
        
        # pink = with
        ax.fill_between(xgrid, 0, y_with, color=COLOR_WITH, alpha=0.30)
        ax.plot(xgrid, y_with, color=COLOR_WITH, lw=1.0)


        # labels & small stats
        n_with = int(np.isfinite(with_vals).sum())
        n_wo   = int(np.isfinite(wo_vals).sum())

        dmean = mean_diff(with_vals, wo_vals)  # mean(without) - mean(with)
        pval = welch_p(with_vals, wo_vals)

        # Write the term label and group sizes at the left of the row.
        txt = f"{label}  |  with={n_with}  without={n_wo}"
        ax.text(0.01, 0.78, txt, transform=ax.transAxes, fontsize=10)

        # Black line: mean difference, expressed as a span on the x axis.
        # It would be drawn near the bottom of the density, from mean_with to mean_wo.
        # The drawing call itself is commented out, so mu_with, mu_wo and ybar
        # are computed but currently unused.
        if np.isfinite(dmean) and n_with >= 5 and n_wo >= 5:
            mu_with = float(np.nanmean(with_vals))
            mu_wo   = float(np.nanmean(wo_vals))
            ybar = 0.02 * max(np.nanmax(y_with), np.nanmax(y_wo), 1e-6)
            # ax.plot([mu_with, mu_wo], [ybar, ybar], lw=2.2, color="k")

            # Red star: significance marker (disabled).
            # if np.isfinite(pval) and pval < 0.05:
            #     xm = 0.5 * (mu_with + mu_wo)
            #     ax.text(xm, ybar * 1.6, "‚òÖ", color="red", ha="center", va="bottom", fontsize=14)

            # Annotate the mean difference and p-value at the right of the row.
            # NOTE(review): the leading characters of these labels look like a
            # mis-encoded Greek capital delta - confirm the file encoding.
            ax.text(0.99, 0.78,
                    f"Œîmean(without-with)={dmean:+.3f} | p={pval:.3g}" if np.isfinite(pval) else f"Œîmean={dmean:+.3f} | p=NA",
                    transform=ax.transAxes, ha="right", fontsize=10)

        # Conditional rows carry no y ticks or y label.
        ax.set_yticks([])
        ax.set_ylabel("")

    axes[-1].set_xlabel("Out-of-sample RMSE (CV_RMSE)")

    # legend (simplified)
    # Two proxy patches explain what the colours mean.
    from matplotlib.patches import Patch
    handles = [
        Patch(facecolor=COLOR_WITHOUT, alpha=0.30, label="without term"),
        Patch(facecolor=COLOR_WITH,    alpha=0.30, label="with term"),
    ]

    # Place the legend at the top right of the first row.
    axes[0].legend(handles=handles, loc="upper right", frameon=False)

    plt.tight_layout()

    out_png = os.path.join(out_dir, f"ED_FIG_CV_ops_{age}.png")
    out_pdf = os.path.join(out_dir, f"ED_FIG_CV_ops_{age}.pdf")
    plt.savefig(out_png, dpi=300, bbox_inches="tight")
    plt.savefig(out_pdf, bbox_inches="tight")
    plt.show()
    plt.close()

    print(f"[Saved] {out_png}")
    print(f"[Saved] {out_pdf}")

# =========================
# RUN ALL AGES
# =========================
def main():
    """Draw the ablation figure for every age folder listed in ``AGE_DIRS``.

    A folder whose results CSV is missing is reported and skipped; the
    figures are written back into the same folder the CSV was read from.
    """
    rule = "=" * 90
    for age, folder in AGE_DIRS.items():
        csv_fp = os.path.join(folder, CSV_NAME)
        if os.path.exists(csv_fp):
            print("\n" + rule)
            print("AGE:", age)
            print("CSV:", csv_fp)
            print(rule)
            plot_one_age(age, csv_fp, folder)
        else:
            print("[Skip] missing:", csv_fp)

# Run the plotting pipeline only when this code is executed directly, not when imported.
if __name__ == "__main__":
    main()


## Theory-constrained exhaustive structural model search with stability filtering

In [None]:
# -*- coding: utf-8 -*-
"""
This procedure does not aim to identify the statistically best-fitting model,
but rather to select a stable and interpretable structural representation consistent with prior theoretical considerations.
"""

import os
import warnings
import itertools
import numpy as np
import pandas as pd

import matplotlib as mpl
import matplotlib.pyplot as plt

import statsmodels.formula.api as smf
import statsmodels.api as sm
from statsmodels.stats.outliers_influence import variance_inflation_factor

# Silences every warning category for the whole process from this point on.
warnings.filterwarnings("ignore")

# =========================
# STYLE
# =========================
# Font type 42 (TrueType) is requested for both PDF and PostScript output.
mpl.rcParams["pdf.fonttype"] = 42
mpl.rcParams["ps.fonttype"]  = 42
plt.rcParams["font.sans-serif"] = ["Arial", "SimHei", "DejaVu Sans"]
plt.rcParams["axes.unicode_minus"] = False

# =========================
# PATHS
# =========================
# Input country-year panel and the root folder that receives one AGE_<age> sub-folder each.
IN_FP = r"D:\AAUDE\paper_v2\paper2\data\model_outputs\panel0_1990_2019_direct_meteo_GBDPM_HAP_lui.csv"
OUT_ROOT = r"D:\AAUDE\paper_v2\paper2\data\model_outputs\FULL_EXHAUST_4EXPO_BY_AGE"
os.makedirs(OUT_ROOT, exist_ok=True)

# =========================
# CONFIG
# =========================
Y0, Y1 = 1990, 2019   # inclusive year window kept from the panel
CV_K   = 5            # number of cross-validation folds
SEED   = 123          # seed for the fold shuffle
SE_TYPE = "HC1"       # covariance type passed to the OLS fit
VIF_CAP = 20.0  # stability threshold: 20 is suggested; raise to 30 if that is too strict

# Added: all four exposure main effects must be significant.
P_MAIN_CAP = 0.05  # tunable: 0.05 / 0.10

# Outcome-count column and population column for each age group.
AGE_SPECS = {
    "total": {"y": "uri_total", "pop": "pop_total"},
    "u5":    {"y": "uri_u5",    "pop": "pop_u5"},
    "5_65":  {"y": "uri_5_65",  "pop": "pop_5_65"},
    "65p":   {"y": "uri_65p",   "pop": "pop_65p"},
}

# Exposure / covariate column names as they appear in the panel file.
HAP_COL  = "hap_pm_pw"
PM25_COL = "pm25_pw"
TAVG_COL = "tavg_pw_C"
AH_COL   = "ah_pw"
SDI_COL  = "sdi"
DENS_COL = "density_total_pkm2"


# =========================
# HELPERS
# =========================
def require_cols(df, cols):
    """Raise ``ValueError`` listing every name in ``cols`` absent from ``df.columns``.

    Returns None when all requested columns are present.
    """
    available = df.columns
    miss = []
    for name in cols:
        if name not in available:
            miss.append(name)
    if len(miss) > 0:
        raise ValueError(f"Missing columns: {miss}")

def safe_log(x, floor=1e-12):
    """Natural log of ``x`` after raising every value below ``floor`` to ``floor``."""
    floored = np.clip(x, floor, None)
    return np.log(floored)

def safe_log1p(x):
    """Return ``log(1 + x)`` with negative inputs treated as zero."""
    non_negative = np.clip(x, 0.0, None)
    return np.log1p(non_negative)

def zscore(s):
    """Standardise ``s`` to zero mean and unit (population) standard deviation.

    Values are coerced to numeric first, with unparseable entries becoming
    NaN and being ignored by the mean and standard deviation. When the
    standard deviation is not finite or is at most 1e-12, a divisor of 1.0
    is used, so a constant input maps to zeros.
    """
    numeric = pd.to_numeric(s, errors="coerce")
    centre = np.nanmean(numeric)
    spread = np.nanstd(numeric)
    usable = np.isfinite(spread) and spread > 1e-12
    if not usable:
        spread = 1.0
    return (numeric - centre) / spread

def kfold_pos_indices(n, k=5, seed=123):
    """Split the positions ``0..n-1`` into ``k`` shuffled folds.

    The shuffle is driven by a generator seeded with ``seed``, so the split
    is reproducible. Returns a list of ``k`` integer arrays produced by
    ``np.array_split``.
    """
    positions = np.arange(n)
    np.random.default_rng(seed).shuffle(positions)
    return np.array_split(positions, k)

def compute_vif_exog(exog, names):
    """Compute a variance inflation factor for each non-constant design column.

    The ``Intercept`` column and any column with (near-)zero standard
    deviation are removed before the VIFs are computed. When at most one
    column remains, the result holds that column (if any) with a NaN VIF.
    Missing cells are filled with the column mean, and a column whose VIF
    computation raises is reported as NaN.

    Returns a DataFrame with columns ``term`` and ``VIF``, sorted by VIF in
    descending order.
    """
    design = pd.DataFrame(exog, columns=names)
    if "Intercept" in design.columns:
        design = design.drop(columns=["Intercept"])

    varying = [col for col in design.columns
               if np.nanstd(design[col].to_numpy()) > 1e-12]
    design = design[varying].copy()
    if design.shape[1] <= 1:
        return pd.DataFrame({"term": design.columns, "VIF": np.nan})

    design = design.fillna(design.mean(numeric_only=True))

    def _vif_at(position):
        # One column at a time, so a single failure does not lose the rest.
        try:
            return float(variance_inflation_factor(design.values, position))
        except Exception:
            return np.nan

    rows = [(col, _vif_at(pos)) for pos, col in enumerate(design.columns)]
    table = pd.DataFrame(rows, columns=["term","VIF"])
    return table.sort_values("VIF", ascending=False)

def make_bins_train_only(train_sdi, q=5):
    """Build ``q`` quantile bin edges from the training SDI values.

    Non-finite values are ignored. Returns None when fewer than
    ``max(20, 4*q)`` finite values remain. Otherwise returns ``q + 1``
    strictly increasing edges whose first and last entries are ``-inf`` and
    ``+inf``, so values outside the training range still fall into a bin.

    Tied quantiles are separated by nudging each offending edge just above
    its predecessor. The nudge is 1e-12 where that is representable; for
    large magnitudes, where adding 1e-12 would be lost to rounding, the next
    representable float is used instead. This keeps the edges strictly
    increasing, which ``pd.cut`` requires.
    """
    finite = np.asarray(train_sdi, dtype=float)
    finite = finite[np.isfinite(finite)]
    if finite.size < max(20, q*4):
        return None

    edges = np.array(np.nanquantile(finite, np.linspace(0, 1, q+1)), dtype=float)
    for i in range(1, len(edges)):
        prev = edges[i-1]
        if edges[i] <= prev:
            # prev + 1e-12 equals prev once prev is large; nextafter never does.
            edges[i] = max(prev + 1e-12, np.nextafter(prev, np.inf))

    # Open-ended outer bins.
    edges[0]  = -np.inf
    edges[-1] = np.inf
    return edges

def assign_q_from_bins(sdi, bins, labels):
    """Label each ``sdi`` value with the bin from ``bins`` it falls into.

    Uses ``pd.cut`` with ``include_lowest=True`` and returns the result as a
    categorical whose categories are ``labels``.
    """
    binned = pd.cut(sdi, bins=bins, labels=labels, include_lowest=True)
    return binned.astype("category")

def merge_top(cat):
    """Collapse the ``Q4`` and ``Q5`` groups into a single ``TOP`` group.

    Returns an ordered ``pd.Categorical`` with categories
    ``Q1 < Q2 < Q3 < TOP``; any label outside that set becomes missing.
    """
    relabel = dict.fromkeys(("Q4", "Q5"), "TOP")
    as_text = cat.astype(str)
    merged = as_text.replace(relabel)
    return pd.Categorical(merged, categories=["Q1","Q2","Q3","TOP"], ordered=True)

def add_sdi_q_foldwise(d_raw, q=5, mergeTop=False, train_pos=None):
    """Return a copy of ``d_raw`` with an ``SDI_Q`` group column added.

    Within a CV fold the quantile thresholds are computed from the rows at
    ``train_pos`` only and then applied to every row, to avoid leaking
    held-out SDI values into the grouping. ``d_raw`` is expected to carry
    ``rid = 0..n-1`` so that positions and ids coincide.

    Threshold source, in order of preference:
      1. the training rows (only when ``train_pos`` is given);
      2. all rows, when the training subset is too small for
         ``make_bins_train_only``;
      3. ``pd.qcut`` on all rows, when even the full sample is too small.

    Step 3 used to apply only when ``train_pos`` was None; in the fold branch
    ``None`` was passed on as the bin edges and the call failed inside
    ``pd.cut``. Both branches now share the same fallback chain.

    Rows without a group are dropped. With ``mergeTop`` the Q4 and Q5 groups
    are merged into ``TOP`` afterwards.
    """
    dd = d_raw.copy()
    labels = [f"Q{i}" for i in range(1, q+1)]

    bins = None
    if train_pos is not None:
        train_sdi = dd.iloc[train_pos]["sdi_mean"].to_numpy()
        bins = make_bins_train_only(train_sdi, q=q)
    if bins is None:
        bins = make_bins_train_only(dd["sdi_mean"].to_numpy(), q=q)

    if bins is None:
        # Too few rows for explicit edges: let pandas pick the quantile cuts.
        dd["SDI_Q"] = pd.qcut(dd["sdi_mean"], q=q, labels=labels).astype("category")
    else:
        dd["SDI_Q"] = assign_q_from_bins(dd["sdi_mean"], bins, labels)

    dd = dd.dropna(subset=["SDI_Q"]).copy()
    if mergeTop:
        dd["SDI_Q"] = merge_top(dd["SDI_Q"])
        dd = dd.dropna(subset=["SDI_Q"]).copy()
    return dd

def add_center_within_q(dd, exposures, center_inter):
    """Return a copy of ``dd``, optionally with group-centred exposure columns.

    When ``center_inter`` is true, a column ``<x>_CQ`` is added for every
    exposure ``x``, holding ``x`` minus the mean of ``x`` within the row's
    ``SDI_Q`` group. Otherwise the copy is returned unchanged.
    """
    out = dd.copy()
    if center_inter:
        for name in exposures:
            group_mean = out.groupby("SDI_Q")[name].transform("mean")
            out[f"{name}_CQ"] = out[name] - group_mean
    return out

def build_formula(exposures, controls, inter_set, modifier_only=True, center_inter=False):
    """Assemble the model formula for ``log_lui_rate_mean``.

    Each exposure contributes one right-hand-side piece:
      * not in ``inter_set``: the exposure alone;
      * in ``inter_set`` with ``modifier_only``: the exposure plus its
        interaction with ``C(SDI_Q)``, using the ``<x>_CQ`` column in the
        interaction when ``center_inter`` is set;
      * in ``inter_set`` without ``modifier_only``: ``x * C(SDI_Q)``.
        ``center_inter`` is not consulted in this case.

    The ``controls`` are appended after the exposures.
    """
    def _piece(x):
        if x not in inter_set:
            return f"{x}"
        if not modifier_only:
            return f"{x} * C(SDI_Q)"
        partner = f"{x}_CQ" if center_inter else x
        return f"{x} + {partner}:C(SDI_Q)"

    rhs = [_piece(x) for x in exposures]
    rhs.extend(controls)
    return "log_lui_rate_mean ~ " + " + ".join(rhs)

def inter_terms(inter_set, cats, center_inter):
    """List the interaction coefficient names for the interacted exposures.

    For every exposure in ``inter_set`` and every category after the first
    (the first is treated as the reference level), the name
    ``<x>:C(SDI_Q)[T.<g>]`` is produced, with ``<x>_CQ`` in place of ``<x>``
    when ``center_inter`` is set.
    """
    suffix = "_CQ" if center_inter else ""
    return [
        f"{x}{suffix}:C(SDI_Q)[T.{g}]"
        for x in inter_set
        for g in cats[1:]
    ]

def wald_pvalue(res, terms):
    """Joint Wald-test p-value that the coefficients named in ``terms`` are zero.

    Names not present in ``res.params`` are ignored. NaN is returned when
    ``terms`` is empty, when none of the names is a model coefficient, or
    when the test raises.
    """
    if not terms:
        return np.nan

    names = res.params.index.tolist()
    present = [t for t in terms if t in names]
    if not present:
        return np.nan

    # One restriction row per tested coefficient, selecting exactly that coefficient.
    restriction = np.zeros((len(present), len(names)))
    rows = np.arange(len(present))
    cols = [names.index(t) for t in present]
    restriction[rows, cols] = 1.0

    try:
        outcome = res.wald_test(restriction)
        return float(np.asarray(outcome.pvalue).ravel()[0])
    except Exception:
        return np.nan

# ‚úÖÊñ∞Â¢ûÔºöÂõõ‰∏ªÊïàÂ∫îÊòæËëóÊÄßÂà§Âà´
def main4_pvals_all_sig(res, exposures, p_cap=0.05):
    """Check whether every exposure main effect is significant.

    Looks up the p-value of each name in ``exposures`` in ``res.pvalues``
    (NaN when absent). An exposure fails when its p-value is not finite or
    is at least ``p_cap``.

    Returns ``(all_sig, pvals)``: a bool and a dict mapping exposure name to
    p-value.
    """
    table = res.pvalues
    pvals = {x: float(table.get(x, np.nan)) for x in exposures}
    all_sig = all(
        np.isfinite(px) and not (px >= p_cap)
        for px in pvals.values()
    )
    return all_sig, pvals

def cv_rmse(formula, d_raw, exposures, q=5, mergeTop=False, center_inter=False, k=5, seed=123):
    """K-fold cross-validated RMSE of ``formula`` on ``log_lui_rate_mean``.

    For every fold the SDI groups are rebuilt from the training rows, the
    model is fitted on the training rows and used to predict the held-out
    rows. A fold is skipped when it has fewer than 30 training or 5 test
    rows, or when fitting/predicting raises. Group-centred exposure columns
    are computed separately for the training and the test rows.

    Returns NaN when fewer than ``max(10, n // 3)`` rows received a finite
    prediction; otherwise the RMSE over the rows that did.
    """
    observed = d_raw["log_lui_rate_mean"].to_numpy()
    n_rows = len(d_raw)
    folds = kfold_pos_indices(n_rows, k=k, seed=seed)
    predicted = np.full_like(observed, np.nan, dtype=float)
    every_pos = np.arange(n_rows, dtype=int)

    for held_out in folds:
        fit_pos = np.setdiff1d(every_pos, held_out)

        # Thresholds come from the training rows; the groups are assigned to all rows.
        grouped = add_sdi_q_foldwise(d_raw, q=q, mergeTop=mergeTop, train_pos=fit_pos)
        rid = grouped["rid"]
        train = grouped[rid.isin(fit_pos)].copy()
        test  = grouped[rid.isin(held_out)].copy()

        if len(train) < 30 or len(test) < 5:
            continue

        train = add_center_within_q(train, exposures, center_inter)
        test  = add_center_within_q(test, exposures, center_inter)

        try:
            fitted = smf.ols(formula, data=train).fit()
            fold_pred = fitted.predict(test)
            predicted[test["rid"].to_numpy().astype(int)] = fold_pred
        except Exception:
            # Best effort: a failing fold simply contributes no predictions.
            continue

    scored = np.isfinite(predicted)
    if scored.sum() < max(10, n_rows//3):
        return np.nan
    residual = predicted[scored] - observed[scored]
    return float(np.sqrt(np.mean(residual**2)))

# =========================
# BUILD DATASET (4 EXPO ALWAYS)
# =========================
def build_dataset_4expo(df_mean, hap_log1p, pm25_log1p, zall, clim_struct):
    """Build the modelling frame with the four exposures and the density control.

    Parameters
    ----------
    df_mean : frame of per-unit means (``hap_mean``, ``pm25_mean``,
        ``tavg_mean``, ``ah_mean``, ``dens_mean``, ``sdi_mean``,
        ``log_lui_rate_mean``).
    hap_log1p, pm25_log1p : apply ``safe_log1p`` to HAP / PM2.5 when true.
    zall : z-score the four exposures and the density control when true.
    clim_struct : ``"both_raw"`` uses temperature and humidity as they are;
        ``"both_AHresid_on_TAVG"`` replaces humidity by its OLS residual on
        temperature; ``"both_TAVGresid_on_AH"`` puts humidity first and
        replaces temperature by its OLS residual on humidity.

    Returns
    -------
    ``(d, exposures, controls)``: the frame without rows missing any required
    column, the four exposure column names and the control column names.

    Raises
    ------
    ValueError for an unrecognised ``clim_struct``.
    """
    d = df_mean.copy()

    d["HAP_X"]  = safe_log1p(d["hap_mean"])  if hap_log1p  else d["hap_mean"]
    d["PM25_X"] = safe_log1p(d["pm25_mean"]) if pm25_log1p else d["pm25_mean"]
    d["TAVG_X"] = d["tavg_mean"]
    d["AH_X"]   = d["ah_mean"]
    d["DENS_X"] = safe_log1p(d["dens_mean"])

    def _residual(target, predictor):
        # Fit on rows with both climate variables, then take residuals for all rows.
        complete = d.dropna(subset=["AH_X","TAVG_X"]).copy()
        fit = smf.ols(f"{target} ~ {predictor}", data=complete).fit()
        return d[target] - fit.predict(d)

    if clim_struct == "both_raw":
        first, second = "TAVG_X", "AH_X"
    elif clim_struct == "both_AHresid_on_TAVG":
        d["AH_resid"] = _residual("AH_X", "TAVG_X")
        first, second = "TAVG_X", "AH_resid"
    elif clim_struct == "both_TAVGresid_on_AH":
        d["TAVG_resid"] = _residual("TAVG_X", "AH_X")
        first, second = "AH_X", "TAVG_resid"
    else:
        raise ValueError("Unknown clim_struct")
    d["CLIM1_X"] = d[first]
    d["CLIM2_X"] = d[second]

    exposures = ["HAP_X","PM25_X","CLIM1_X","CLIM2_X"]  # always four exposures
    controls  = ["DENS_X"]

    if zall:
        for col in exposures + controls:
            d[col] = zscore(d[col])

    need = ["log_lui_rate_mean","sdi_mean"] + exposures + controls
    d = d.dropna(subset=need).copy()

    return d, exposures, controls


# =========================
# ONE-AGE EXHAUST RUN
# =========================
def run_exhaust_one_age(df_mean, out_dir, vif_cap=20.0, verbose=True):
    """Fit every model in the design space for one age group and export the best.

    The design space crosses HAP transform (2) x PM2.5 transform (2) x
    z-scoring (2) x climate structure (3) x SDI scheme (2) x interaction
    coding (2) x within-group centring (2) x every subset of the exposures
    to interact with SDI group. With the four exposures returned by
    ``build_dataset_4expo`` that is 3072 jobs.

    For each job the model is fitted on all rows, then CV RMSE, the joint
    Wald p-value of the interaction terms and the maximum VIF are computed;
    each of those steps records its own error text instead of aborting.

    The best model is the first row, sorted by CV_RMSE, BIC and -AdjR2
    (all ascending), among models that fitted, have all three criteria,
    used every ``iso3`` in ``df_mean``, have all exposure main effects
    significant, and have a finite maxVIF of at most ``vif_cap`` (or at most
    50 when none passes ``vif_cap``).

    Files written to ``out_dir``: the base snapshot, the full results table,
    the best row, and for the refitted best model a text summary, the
    coefficient table, the VIF table and, per interacted exposure, a
    slope-by-SDI-group CSV plus PNG/PDF plot.

    Returns
    -------
    ``(best, df_res)``: the best row as a dict and the full results frame.

    Raises
    ------
    RuntimeError when no model survives the filters.
    """
    os.makedirs(out_dir, exist_ok=True)

    # Design space
    PIPE_HAP  = [("hap_raw", False), ("hap_log1p", True)]
    PIPE_PM25 = [("pm25_raw", False), ("pm25_log1p", True)]
    PIPE_ZALL = [("noz", False), ("zall", True)]

    CLIM_STRUCTS = ["both_raw", "both_AHresid_on_TAVG", "both_TAVGresid_on_AH"]

    # (name, number of quantile groups, merge Q4+Q5 into TOP)
    SDI_SCHEMES = [("Q5", 5, False), ("Q4_mergeTop", 5, True)]
    CODINGS = [("modifier_only", True), ("star_expand", False)]
    CENTER_OPTS = [("nocenter", False), ("centerWithinQ", True)]

    # Save a snapshot of the base data
    base_fp = os.path.join(out_dir, "iso3_longterm_means_base.csv")
    df_mean.to_csv(base_fp, index=False, encoding="utf-8-sig")
    if verbose:
        print("Saved:", base_fp, "| n_iso3:", df_mean["iso3"].nunique())

    # Number of units a model must retain to be eligible as "best".
    N_ISO = int(df_mean["iso3"].nunique())

    ALL = []
    job_id = 0

    for hap_name, hap_log1p in PIPE_HAP:
        for pm_name, pm25_log1p in PIPE_PM25:
            for z_name, zall in PIPE_ZALL:
                for clim_struct in CLIM_STRUCTS:

                    d_raw, exposures, controls = build_dataset_4expo(df_mean, hap_log1p, pm25_log1p, zall, clim_struct)
                    d_raw = d_raw.reset_index(drop=True).copy()
                    # rid equals the row position; the CV code relies on that.
                    d_raw["rid"] = np.arange(len(d_raw), dtype=int)

                    for sdi_scheme, q, mergeTop in SDI_SCHEMES:
                        # Full-sample grouping for the main fit (no held-out rows here).
                        d_fit = add_sdi_q_foldwise(d_raw, q=q, mergeTop=mergeTop, train_pos=None)

                        for coding_name, modifier_only in CODINGS:
                            for center_name, center_inter in CENTER_OPTS:
                                d0 = add_center_within_q(d_fit, exposures, center_inter)

                                try:
                                    cats = list(d0["SDI_Q"].cat.categories)
                                except Exception:
                                    cats = None

                                # Every subset of the exposures, from none to all of them.
                                for r in range(0, len(exposures)+1):
                                    for comb in itertools.combinations(exposures, r):
                                        inter_set = set(comb)
                                        tag = "+".join(sorted(inter_set)) if inter_set else "NO_INTER"
                                        formula = build_formula(exposures, controls, inter_set, modifier_only, center_inter)

                                        job_id += 1

                                        # Defaults reported when a step fails or is skipped.
                                        aic=bic=adjr2=cv=waldp=max_vif=np.nan
                                        fit_ok=0
                                        err_main=err_cv=err_vif=err_wald=""
                                        main4_all_sig = 0
                                        p_HAP=p_PM25=p_C1=p_C2 = np.nan

                                        # main fit
                                        try:
                                            res = smf.ols(formula, data=d0).fit(cov_type=SE_TYPE)
                                            fit_ok = 1
                                            aic = float(res.aic) if np.isfinite(res.aic) else np.nan
                                            bic = float(res.bic) if np.isfinite(res.bic) else np.nan
                                            adjr2 = float(res.rsquared_adj) if np.isfinite(res.rsquared_adj) else np.nan

                                            # Added: significance of the four main effects
                                            all_sig, p_main = main4_pvals_all_sig(res, exposures, p_cap=P_MAIN_CAP)
                                            main4_all_sig = int(all_sig)
                                            p_HAP  = p_main.get("HAP_X", np.nan)
                                            p_PM25 = p_main.get("PM25_X", np.nan)
                                            p_C1   = p_main.get("CLIM1_X", np.nan)
                                            p_C2   = p_main.get("CLIM2_X", np.nan)

                                        except Exception as e:
                                            # Record the failed job with whatever was computed, then move on.
                                            err_main = str(e)[:200]
                                            ALL.append({
                                                "job_id": job_id,
                                                "hap": hap_name, "pm25": pm_name, "zall": z_name,
                                                "clim_struct": clim_struct,
                                                "sdi_scheme": sdi_scheme,
                                                "coding": coding_name,
                                                "center_inter": center_name,
                                                "interactions": tag,
                                                "n_fit": int(len(d0)),
                                                "formula": formula,
                                                "AIC": aic, "BIC": bic, "AdjR2": adjr2,
                                                "CV_RMSE": cv, "WaldP_inter": waldp, "maxVIF": max_vif,
                                                "main4_all_sig": int(main4_all_sig),
                                                "p_HAP_main": p_HAP,
                                                "p_PM25_main": p_PM25,
                                                "p_CLIM1_main": p_C1,
                                                "p_CLIM2_main": p_C2,
                                                "fit_ok": fit_ok,
                                                "err_main": err_main, "err_cv": err_cv, "err_vif": err_vif, "err_wald": err_wald,
                                            })
                                            continue

                                        # CV
                                        # Starts from d_raw so the SDI groups are rebuilt inside each fold.
                                        try:
                                            cv = cv_rmse(formula, d_raw, exposures, q=q, mergeTop=mergeTop,
                                                         center_inter=center_inter, k=CV_K, seed=SEED)
                                        except Exception as e:
                                            err_cv = str(e)[:200]
                                            cv = np.nan

                                        # Wald
                                        try:
                                            if cats is not None:
                                                terms = inter_terms(inter_set, cats, center_inter)
                                                waldp = wald_pvalue(res, terms)
                                        except Exception as e:
                                            err_wald = str(e)[:200]
                                            waldp = np.nan

                                        # VIF
                                        try:
                                            vif_df = compute_vif_exog(res.model.exog, res.model.exog_names)
                                            max_vif = float(np.nanmax(vif_df["VIF"].to_numpy())) if len(vif_df) else np.nan
                                        except Exception as e:
                                            err_vif = str(e)[:200]
                                            max_vif = np.nan

                                        ALL.append({
                                            "job_id": job_id,
                                            "hap": hap_name, "pm25": pm_name, "zall": z_name,
                                            "clim_struct": clim_struct,
                                            "sdi_scheme": sdi_scheme,
                                            "coding": coding_name,
                                            "center_inter": center_name,
                                            "interactions": tag,
                                            "n_fit": int(len(d0)),
                                            "formula": formula,
                                            "AIC": aic, "BIC": bic, "AdjR2": adjr2,
                                            "CV_RMSE": cv, "WaldP_inter": waldp, "maxVIF": max_vif,
                                            "main4_all_sig": int(main4_all_sig),
                                            "p_HAP_main": p_HAP,
                                            "p_PM25_main": p_PM25,
                                            "p_CLIM1_main": p_C1,
                                            "p_CLIM2_main": p_C2,
                                            "fit_ok": fit_ok,
                                            "err_main": err_main, "err_cv": err_cv, "err_vif": err_vif, "err_wald": err_wald,
                                        })

    df_res = pd.DataFrame(ALL)
    df_res["AdjR2"] = pd.to_numeric(df_res["AdjR2"], errors="coerce")
    # Negated so that "smaller is better" holds for every sort key.
    df_res["AdjR2_neg"] = -df_res["AdjR2"]
    for c in ["CV_RMSE","BIC","AdjR2_neg","AIC","maxVIF","WaldP_inter",
              "p_HAP_main","p_PM25_main","p_CLIM1_main","p_CLIM2_main","main4_all_sig"]:
        if c in df_res.columns:
            df_res[c] = pd.to_numeric(df_res[c], errors="coerce")

    res_fp = os.path.join(out_dir, "FULL_EXHAUST_4EXPO_results_all_models.csv")
    df_res.to_csv(res_fp, index=False, encoding="utf-8-sig")
    if verbose:
        print("Saved:", res_fp)
        print("Total jobs:", len(df_res), "| fit_ok:", int(df_res["fit_ok"].sum()))

    # pick best (stable)
    df_ok = df_res[df_res["fit_ok"]==1].copy()
    df_ok = df_ok.dropna(subset=["CV_RMSE","BIC","AdjR2_neg"], how="any").copy()
    # Keep only models fitted on every unit of the base data.
    df_ok = df_ok[df_ok["n_fit"] == N_ISO].copy()

    # Added: the four main effects must all be significant
    df_ok = df_ok[df_ok["main4_all_sig"] == 1].copy()

    df_stable = df_ok[np.isfinite(df_ok["maxVIF"]) & (df_ok["maxVIF"] <= float(vif_cap))].copy()
    if len(df_stable) == 0:
        # Relax the cap to 50 so that the search does not end without a candidate
        df_stable = df_ok[np.isfinite(df_ok["maxVIF"]) & (df_ok["maxVIF"] <= 50)].copy()

    if len(df_stable) == 0:
        raise RuntimeError(
            "No stable model found after requiring 4-main p<cap. "
            "Try relax P_MAIN_CAP (e.g., 0.10) or relax VIF_CAP, or check data."
        )

    df_stable = df_stable.sort_values(["CV_RMSE","BIC","AdjR2_neg"], ascending=[True, True, True]).copy()
    best = df_stable.iloc[0].to_dict()

    best_fp = os.path.join(out_dir, "BEST_model_row.csv")
    pd.DataFrame([best]).to_csv(best_fp, index=False, encoding="utf-8-sig")
    if verbose:
        print("Saved:", best_fp)
        print("[BEST MODEL (stable, main4 p<cap)]")
        for k in ["hap","pm25","zall","clim_struct","sdi_scheme","coding","center_inter","interactions",
                  "CV_RMSE","BIC","AdjR2","WaldP_inter","maxVIF","n_fit",
                  "main4_all_sig","p_HAP_main","p_PM25_main","p_CLIM1_main","p_CLIM2_main",
                  "formula"]:
            print(f"  {k}: {best.get(k)}")

    # refit best to export summary/coef/vif and slope plot
    # The design flags are recovered from the names stored in the best row.
    hap_log1p = (best["hap"] == "hap_log1p")
    pm25_log1p = (best["pm25"] == "pm25_log1p")
    zall = (best["zall"] == "zall")
    clim_struct = best["clim_struct"]
    mergeTop = (best["sdi_scheme"] == "Q4_mergeTop")
    center_inter = (best["center_inter"] == "centerWithinQ")

    d_raw_best, exposures, controls = build_dataset_4expo(df_mean, hap_log1p, pm25_log1p, zall, clim_struct)
    d_raw_best = d_raw_best.reset_index(drop=True).copy()
    d_raw_best["rid"] = np.arange(len(d_raw_best), dtype=int)

    # q=5 matches both entries of SDI_SCHEMES above.
    d_fit_best = add_sdi_q_foldwise(d_raw_best, q=5, mergeTop=mergeTop, train_pos=None)
    d_fit_best = add_center_within_q(d_fit_best, exposures, center_inter)

    best_formula = best["formula"]
    res_best = smf.ols(best_formula, data=d_fit_best).fit(cov_type=SE_TYPE)

    # summary
    summ_fp = os.path.join(out_dir, "BEST_model_summary.txt")
    with open(summ_fp, "w", encoding="utf-8") as f:
        f.write("BEST FORMULA:\n" + best_formula + "\n\n")
        f.write(f"\n[Main-4 p-value threshold] P_MAIN_CAP={P_MAIN_CAP}\n")
        f.write("Main-4 p-values:\n")
        for x in exposures:
            f.write(f"  {x}: {float(res_best.pvalues.get(x, np.nan))}\n")
        f.write("\n")
        f.write(res_best.summary().as_text())
    if verbose:
        print("Saved:", summ_fp)

    # coef
    coef = pd.DataFrame({
        "term": res_best.params.index,
        "beta": res_best.params.values,
        "se": res_best.bse.values,
        "p": res_best.pvalues.values
    }).sort_values("p", ascending=True)

    coef_fp = os.path.join(out_dir, "BEST_model_coef.csv")
    coef.to_csv(coef_fp, index=False, encoding="utf-8-sig")
    if verbose:
        print("Saved:", coef_fp)

    # VIF
    try:
        vif_df = compute_vif_exog(res_best.model.exog, res_best.model.exog_names)
        vif_fp = os.path.join(out_dir, "BEST_model_VIF.csv")
        vif_df.to_csv(vif_fp, index=False, encoding="utf-8-sig")
        if verbose:
            print("Saved:", vif_fp)
    except Exception as e:
        if verbose:
            print("VIF failed:", str(e)[:200])

    # slope plot for interacted exposures
    try:
        cats = list(d_fit_best["SDI_Q"].cat.categories)
    except Exception:
        cats = None

    def slope_by_group(res, x, cats, center_inter=False):
        """Slope of ``x`` in each SDI group with a 95% normal-approximation interval.

        The first category uses the main-effect coefficient of ``x``; every
        other category adds its interaction coefficient, and the standard
        error comes from the variance of that sum. Groups whose interaction
        coefficient is absent get NaN.
        """
        b = res.params
        V = res.cov_params()
        x_int = f"{x}_CQ" if center_inter else x

        out = []
        # Reference group: main effect only.
        g0 = cats[0]
        b0 = float(b.get(x, np.nan))
        v0 = float(V.loc[x, x]) if (x in V.index) else np.nan
        se0 = float(np.sqrt(max(v0, 0.0))) if np.isfinite(v0) else np.nan
        out.append((g0, b0, se0))

        for g in cats[1:]:
            term = f"{x_int}:C(SDI_Q)[T.{g}]"
            if (x in b.index) and (term in b.index):
                bg = float(b[x] + b[term])
                # Var(a + b) = Var(a) + Var(b) + 2 Cov(a, b)
                var = float(V.loc[x, x] + V.loc[term, term] + 2.0*V.loc[x, term])
                seg = float(np.sqrt(max(var, 0.0)))
            else:
                bg, seg = np.nan, np.nan
            out.append((g, bg, seg))

        df_s = pd.DataFrame(out, columns=["SDI_Q","slope","se"])
        df_s["ci_lo"] = df_s["slope"] - 1.96*df_s["se"]
        df_s["ci_hi"] = df_s["slope"] + 1.96*df_s["se"]
        return df_s

    # An exposure counts as interacted when at least one of its interaction
    # coefficients is present in the refitted model.
    params = set(res_best.params.index.tolist())
    interacted = []
    if cats is not None:
        for x in exposures:
            x_int = f"{x}_CQ" if center_inter else x
            has_any = any((f"{x_int}:C(SDI_Q)[T.{g}]" in params) for g in cats[1:])
            if has_any:
                interacted.append(x)

    if cats is not None and interacted:
        for x in interacted:
            df_s = slope_by_group(res_best, x, cats, center_inter=center_inter)
            out_csv = os.path.join(out_dir, f"BEST_slope_by_SDI__{x}.csv")
            df_s.to_csv(out_csv, index=False, encoding="utf-8-sig")

            plt.figure(figsize=(7.2, 4.6))
            xx = np.arange(1, len(cats)+1)
            plt.plot(xx, df_s["slope"].values, marker="o", lw=2.2)
            plt.fill_between(xx, df_s["ci_lo"].values, df_s["ci_hi"].values, alpha=0.18)
            plt.axhline(0, lw=1, color="k", alpha=0.6)
            plt.xticks(xx, df_s["SDI_Q"].astype(str).values)
            plt.xlabel("SDI group")
            plt.ylabel(f"Slope of {x} on log(mean LUI rate)")
            plt.title(f"BEST model: {x} slope across SDI groups")
            plt.tight_layout()

            out_png = os.path.join(out_dir, f"BEST_slope_{x}_by_SDI.png")
            out_pdf = os.path.join(out_dir, f"BEST_slope_{x}_by_SDI.pdf")
            plt.savefig(out_png, dpi=300, bbox_inches="tight")
            plt.savefig(out_pdf, bbox_inches="tight")
            plt.close()

    return best, df_res


# =========================
# MAIN: LOAD PANEL ONCE, RUN ALL AGES
# =========================
def main():
    """Load the panel once, run the exhaustive model search per age group, save the winners.

    For every entry of ``AGE_SPECS`` the panel is reduced to per-country
    means, ``run_exhaust_one_age`` is called on those means, and the returned
    best-model record is tagged with the age key and its outcome/population
    columns. All records are written to ``BEST_MODELS_ALL_AGES.csv`` under
    ``OUT_ROOT`` and a subset of columns is printed.
    """
    print("Loading panel:", IN_FP)
    panel = pd.read_csv(IN_FP)

    panel["iso3"] = panel["iso3"].astype(str).str.upper().str.strip()
    panel["year"] = pd.to_numeric(panel["year"], errors="coerce")
    in_window = (panel["year"] >= Y0) & (panel["year"] <= Y1)
    panel = panel[in_window].copy()

    rule = "=" * 90
    best_rows = []

    for age, spec in AGE_SPECS.items():
        y_col, pop_col = spec["y"], spec["pop"]

        out_dir = os.path.join(OUT_ROOT, f"AGE_{age}")
        os.makedirs(out_dir, exist_ok=True)

        print("\n" + rule)
        print(f"üöÄ AGE={age} | y={y_col} | pop={pop_col}")
        print(rule)

        need = ["iso3", "year", y_col, pop_col]
        need += [HAP_COL, PM25_COL, TAVG_COL, AH_COL]
        need += [SDI_COL, DENS_COL]
        require_cols(panel, need)

        df = panel.dropna(subset=need).copy()
        for col in (y_col, pop_col):
            df[col] = pd.to_numeric(df[col], errors="coerce")

        df = df.dropna(subset=[y_col, pop_col]).copy()
        # population is floored at 1 before dividing
        df["rate"] = df[y_col] / df[pop_col].clip(lower=1.0) * 100_000.0

        # per-country aggregation: output column -> (input column, reducer)
        mean_spec = {
            "lui_rate_mean": ("rate", "mean"),
            "hap_mean": (HAP_COL, "mean"),
            "pm25_mean": (PM25_COL, "mean"),
            "tavg_mean": (TAVG_COL, "mean"),
            "ah_mean": (AH_COL, "mean"),
            "sdi_mean": (SDI_COL, "mean"),
            "dens_mean": (DENS_COL, "mean"),
            "n_years": ("year", "nunique"),
        }
        df_mean = df.groupby("iso3", as_index=False).agg(**mean_spec)
        df_mean["log_lui_rate_mean"] = safe_log(df_mean["lui_rate_mean"], floor=1e-6)

        # run the exhaustive search
        best, _ = run_exhaust_one_age(df_mean, out_dir, vif_cap=VIF_CAP, verbose=True)

        best_row = best.copy()
        for key, value in (("age", age), ("y_col", y_col), ("pop_col", pop_col)):
            best_row[key] = value
        best_rows.append(best_row)

    # summary across age groups
    df_best = pd.DataFrame(best_rows)
    out_fp = os.path.join(OUT_ROOT, "BEST_MODELS_ALL_AGES.csv")
    df_best.to_csv(out_fp, index=False, encoding="utf-8-sig")

    print("\nüéØ DONE. Summary saved:")
    print(out_fp)
    wanted = [
        "age", "formula", "CV_RMSE", "BIC", "AdjR2", "WaldP_inter", "maxVIF", "interactions",
        "coding", "center_inter", "clim_struct", "hap", "pm25", "zall", "sdi_scheme",
        "main4_all_sig", "p_HAP_main", "p_PM25_main", "p_CLIM1_main", "p_CLIM2_main",
    ]
    # only print the columns that the result table actually has
    show_cols = [c for c in wanted if c in df_best.columns]
    print(df_best[show_cols])

if __name__ == "__main__":
    main()

## Model used in this issue

In [None]:
# -*- coding: utf-8 -*-
import os
import warnings
import numpy as np
import pandas as pd

import matplotlib as mpl
import matplotlib.pyplot as plt

import statsmodels.formula.api as smf
from patsy import dmatrix, build_design_matrices

import geopandas as gpd
import cartopy.crs as ccrs
from matplotlib.colors import Normalize

warnings.filterwarnings("ignore")

# =========================
# STYLE
# =========================
mpl.rcParams["pdf.fonttype"] = 42
mpl.rcParams["ps.fonttype"]  = 42
plt.rcParams["font.sans-serif"] = ["Arial", "SimHei", "DejaVu Sans"]
plt.rcParams["axes.unicode_minus"] = False

# =========================
# PATHS
# =========================
IN_FP  = r"D:\AAUDE\paper_v2\paper2\data\model_outputs\panel0_1990_2019_direct_meteo_GBDPM_HAP_lui.csv"
SHP_FP = r"D:\AAUDE\paper_v2\paper2\data\ne_10m_admin_0_countries\ne_10m_admin_0_countries.shp"
OUT_ROOT = r"D:\AAUDE\paper_v2\paper2\data\model_outputs\FIG_AIR_2x2_ALL_AGES_IMPACT_PCT_Q5PLOT_CI_HATCH_v2"
os.makedirs(OUT_ROOT, exist_ok=True)

# =========================
# GLOBAL CONFIG
# =========================
# Inclusive year window: main() and prep_panel_dataset() keep rows with Y0 <= year <= Y1.
Y0, Y1 = 1990, 2019
# Year whose rows feed the per-country impact maps in run_one_age().
MAP_YEAR = 2019
# Multiplier applied to y / pop in build_iso3_means() (rate per RATE_PER people).
RATE_PER = 100_000.0

# Column names expected in the input panel (checked with require_cols() in main()).
HAP_RAW  = "hap_pm_pw"
PM25_RAW = "pm25_pw"
TAVG_RAW = "tavg_pw_C"
AH_RAW   = "ah_pw"
DENS_RAW = "density_total_pkm2"
SDI_RAW  = "sdi"

# When True, main() divides both TMREL bases below by 10 before use.
AUTO_SCALE_DIV10 = False

# Counterfactual exposure caps: run_one_age() replaces raw exposure by min(raw, TMREL).
TMREL_HAP_BASE  = 2.4
TMREL_PM25_BASE = 2.4

# SDI colors (Q1..Q5)
SDI_LABELS_Q5 = ["Q1","Q2","Q3","Q4","Q5"]
SDI_HEX_Q5 = [ "#547BB4", "#DD7C4F","#629C35", "#C0321A", "#6C61AF"]

# ---- CI simulation ----
# SIM_N coefficient draws per interval; run_one_age() seeds HAP draws with
# SIM_SEED and PM2.5 draws with SIM_SEED + 1.
SIM_N = 3000
SIM_SEED = 123
# Lower / upper percentiles taken across the simulated draws.
CI_LO, CI_HI = 2.5, 97.5

# =========================
# AGE-SPECIFIC BEST SETTINGS
# =========================
# One entry per age group, keyed by the age label used in file names and titles.
# Keys of each entry, as consumed elsewhere in this cell:
#   y, pop        outcome and population columns of the panel (build_iso3_means)
#   hap, pm25     transform mode passed to transform_hap_from_raw / transform_pm25_from_raw
#   sdi_scheme    SDI grouping used by the model (make_sdi_maps_from_means)
#   clim_struct   climate residualisation choice (fit_resid_params_on_means)
#   center_inter  "centerWithinQ" triggers add_center_within_q; other values skip it
#   rhs           right-hand side of the OLS formula fitted in run_one_age
#   title_tag     name of the output sub-folder under OUT_ROOT
AGE_CFG = {
    "total": dict(
        y="uri_total", pop="pop_total",
        hap="hap_log1p",
        pm25="pm25_raw",
        sdi_scheme="Q4_mergeTop",
        clim_struct="both_TAVGresid_on_AH",
        center_inter="nocenter",
        rhs="HAP_X + PM25_X + PM25_X:C(SDI_Q) + CLIM1_X + CLIM2_X + DENS_X",
        title_tag="AGE_total"
    ),
    "u5": dict(
        y="uri_u5", pop="pop_u5",
        hap="hap_log1p",
        pm25="pm25_log1p",
        sdi_scheme="Q4_mergeTop",
        clim_struct="both_TAVGresid_on_AH",
        center_inter="nocenter",
        rhs="HAP_X + PM25_X * C(SDI_Q) + CLIM1_X * C(SDI_Q) + CLIM2_X * C(SDI_Q) + DENS_X",
        title_tag="AGE_u5"
    ),
    "5_65": dict(
        y="uri_5_65", pop="pop_5_65",
        hap="hap_raw",
        pm25="pm25_raw",
        sdi_scheme="Q5",
        clim_struct="both_TAVGresid_on_AH",
        center_inter="nocenter",
        rhs="HAP_X + HAP_X:C(SDI_Q) + PM25_X + CLIM1_X + CLIM2_X + DENS_X",
        title_tag="AGE_5_65"
    ),
    "65p": dict(
        y="uri_65p", pop="pop_65p",
        hap="hap_log1p",
        pm25="pm25_raw",
        sdi_scheme="Q4_mergeTop",
        clim_struct="both_TAVGresid_on_AH",
        center_inter="nocenter",
        rhs="HAP_X + PM25_X + PM25_X:C(SDI_Q) + CLIM1_X + CLIM2_X + DENS_X",
        title_tag="AGE_65p"
    ),
}

# =========================
# helpers
# =========================
def require_cols(df, cols):
    """Raise ``ValueError`` listing every name in *cols* that is not a column of *df*."""
    absent = []
    for name in cols:
        if name not in df.columns:
            absent.append(name)
    if absent:
        raise ValueError(f"Missing columns: {absent}")

def safe_log(x, floor=1e-12):
    """Return the natural log of *x* after raising values below *floor* to *floor*."""
    floored = np.clip(x, floor, None)
    return np.log(floored)

def safe_log1p(x):
    """Return ``log(1 + x)`` with negative inputs clipped to zero first."""
    non_negative = np.clip(x, 0.0, None)
    return np.log1p(non_negative)

def load_world(shp_fp):
    """Read the country shapefile and return it with a cleaned ``iso3`` column.

    The first column found among the candidate names is renamed to ``iso3``
    and upper-cased/stripped. Rows coded ``-99`` and the row for ``ATA`` are
    removed.

    Raises:
        ValueError: if none of the candidate code columns is present.
    """
    world = gpd.read_file(shp_fp)

    candidates = ("ADM0_A3", "ISO_A3", "SOV_A3", "WB_A3", "ISO3", "iso3")
    code_col = next((name for name in candidates if name in world.columns), None)
    if code_col is None:
        raise ValueError(f"Shapefile lacks iso3. Available: {list(world.columns)}")

    world = world.rename(columns={code_col: "iso3"})
    world["iso3"] = world["iso3"].astype(str).str.upper().str.strip()

    # "-99" is treated as "no code": blank it out so dropna removes the row
    placeholder = world["iso3"] == "-99"
    world.loc[placeholder, "iso3"] = np.nan
    world = world.dropna(subset=["iso3"]).copy()

    not_antarctica = world["iso3"] != "ATA"
    return world[not_antarctica].copy()

def build_iso3_means(df_panel, y_col, pop_col):
    """Collapse the panel to one row per ``iso3`` holding per-country means.

    A ``rate`` column is computed as ``y / max(pop, 1) * RATE_PER``; it and the
    raw exposure/control columns are averaged by country. The log of the mean
    rate (floored at 1e-6) is added as ``log_lui_rate_mean``.
    """
    panel = df_panel.copy()
    panel["rate"] = panel[y_col] / panel[pop_col].clip(lower=1.0) * RATE_PER

    # output column -> source column; every one is reduced with "mean"
    sources = {
        "lui_rate_mean": "rate",
        "hap_mean": HAP_RAW,
        "pm25_mean": PM25_RAW,
        "tavg_mean": TAVG_RAW,
        "ah_mean": AH_RAW,
        "sdi_mean": SDI_RAW,
        "dens_mean": DENS_RAW,
    }
    named_aggs = {out_col: (src_col, "mean") for out_col, src_col in sources.items()}
    df_mean = panel.groupby("iso3", as_index=False).agg(**named_aggs)

    df_mean["log_lui_rate_mean"] = safe_log(df_mean["lui_rate_mean"], floor=1e-6)
    return df_mean

def make_sdi_maps_from_means(df_mean, scheme_model="Q4_mergeTop"):
    """Build iso3 -> SDI-group lookups for modelling and for plotting.

    Countries are cut into quintiles ``Q1``..``Q5`` of ``sdi_mean``. The
    plotting lookup always uses the quintiles. The model lookup is either the
    same quintiles (``"Q5"``) or the quintiles with ``Q4`` and ``Q5`` merged
    into ``TOP`` (``"Q4_mergeTop"``).

    Returns:
        ``(sdi_map_model, labels_model, sdi_map_plot5, labels_plot5)``.

    Raises:
        ValueError: for any other *scheme_model*.
    """
    sdi_by_iso = df_mean.set_index("iso3")["sdi_mean"].astype(float)

    # quintiles used for plotting
    labels_plot5 = [f"Q{i}" for i in range(1, 6)]
    quintile = pd.qcut(sdi_by_iso, q=5, labels=labels_plot5, duplicates="drop").astype(str)
    sdi_map_plot5 = quintile.to_dict()

    # grouping used by the model
    if scheme_model == "Q5":
        return sdi_map_plot5, labels_plot5, sdi_map_plot5, labels_plot5
    if scheme_model == "Q4_mergeTop":
        merged = quintile.replace({"Q4": "TOP", "Q5": "TOP"})
        return merged.to_dict(), ["Q1", "Q2", "Q3", "TOP"], sdi_map_plot5, labels_plot5
    raise ValueError(f"Unknown scheme_model: {scheme_model}")

def fit_resid_params_on_means(df_mean, clim_struct):
    """Fit the auxiliary climate regression and return ``(intercept, slope, mode)``.

    ``"both_raw"`` needs no regression and yields ``(None, None, "RAW")``.
    The two residual schemes regress one of ``ah_mean`` / ``tavg_mean`` on the
    other by OLS over rows where both are present, and return the fitted
    intercept and slope with a mode tag read by ``apply_climate_struct``.

    Raises:
        ValueError: for an unrecognised *clim_struct*.
    """
    if clim_struct == "both_raw":
        return None, None, "RAW"

    # (scheme name, response column, predictor column, mode tag)
    schemes = (
        ("both_AHresid_on_TAVG", "ah_mean", "tavg_mean", "AH_on_TAVG"),
        ("both_TAVGresid_on_AH", "tavg_mean", "ah_mean", "TAVG_on_AH"),
    )
    for name, response, predictor, mode in schemes:
        if clim_struct == name:
            complete = df_mean.dropna(subset=["ah_mean","tavg_mean"]).copy()
            fit = smf.ols(f"{response} ~ {predictor}", data=complete).fit()
            return float(fit.params["Intercept"]), float(fit.params[predictor]), mode

    raise ValueError(f"Unknown clim_struct: {clim_struct}")

def transform_hap_from_raw(x_raw, hap_mode):
    """Coerce raw HAP values to float and apply the transform named by *hap_mode*.

    ``"hap_raw"`` returns the coerced values unchanged; ``"hap_log1p"``
    returns ``safe_log1p`` of them. Unparseable entries become NaN.

    Raises:
        ValueError: for any other *hap_mode*.
    """
    values = pd.to_numeric(x_raw, errors="coerce").astype(float)
    if hap_mode == "hap_raw":
        return values
    if hap_mode == "hap_log1p":
        return safe_log1p(values)
    raise ValueError(f"Unknown hap mode: {hap_mode}")

def transform_pm25_from_raw(x_raw, pm_mode):
    """Coerce raw PM2.5 values to float and apply the transform named by *pm_mode*.

    ``"pm25_raw"`` returns the coerced values unchanged; ``"pm25_log1p"``
    returns ``safe_log1p`` of them. Unparseable entries become NaN.

    Raises:
        ValueError: for any other *pm_mode*.
    """
    values = pd.to_numeric(x_raw, errors="coerce").astype(float)
    if pm_mode == "pm25_log1p":
        return safe_log1p(values)
    if pm_mode == "pm25_raw":
        return values
    raise ValueError(f"Unknown pm25 mode: {pm_mode}")

def apply_climate_struct(d, resid_params, is_means=True):
    """Return a copy of *d* with the climate regressors ``CLIM1_X`` and ``CLIM2_X``.

    Temperature and humidity are read from ``tavg_mean`` / ``ah_mean`` when
    *is_means* is true, otherwise from ``TAVG`` / ``AH``. *resid_params* is the
    ``(intercept, slope, mode)`` triple:

    * ``"RAW"``        -> CLIM1 = temperature, CLIM2 = humidity
    * ``"AH_on_TAVG"`` -> CLIM2 is humidity minus its linear fit on temperature
    * ``"TAVG_on_AH"`` -> CLIM1 is temperature minus its linear fit on humidity

    Raises:
        ValueError: carrying the mode when it is none of the above.
    """
    a0, b1, mode = resid_params
    out = d.copy()

    temp_col, hum_col = ("tavg_mean", "ah_mean") if is_means else ("TAVG", "AH")
    temp = pd.to_numeric(out[temp_col], errors="coerce")
    hum = pd.to_numeric(out[hum_col], errors="coerce")

    if mode not in ("RAW", "AH_on_TAVG", "TAVG_on_AH"):
        raise ValueError(mode)

    # only the residualised variable is replaced; the other stays raw
    out["CLIM1_X"] = temp - (a0 + b1 * hum) if mode == "TAVG_on_AH" else temp
    out["CLIM2_X"] = hum - (a0 + b1 * temp) if mode == "AH_on_TAVG" else hum
    return out

def interacted_exposures_from_rhs(rhs):
    """Return the set of exposure names that *rhs* interacts with ``C(SDI_Q)``.

    Detection is by plain substring search for ``X:C(SDI_Q)``,
    ``X_CQ:C(SDI_Q)``, ``X*C(SDI_Q)`` or ``X * C(SDI_Q)``.
    """
    templates = ("{x}:C(SDI_Q)", "{x}_CQ:C(SDI_Q)", "{x}*C(SDI_Q)", "{x} * C(SDI_Q)")
    exposures = ("HAP_X", "PM25_X", "CLIM1_X", "CLIM2_X")
    return {
        name
        for name in exposures
        if any(tpl.format(x=name) in rhs for tpl in templates)
    }

def add_center_within_q(d, rhs, group_col="SDI_Q"):
    """Return a copy of *d* with a group-centred ``<X>_CQ`` column per interacted exposure.

    The exposures are those reported by ``interacted_exposures_from_rhs(rhs)``;
    each is centred on its mean within *group_col*, computed on *d* itself.
    """
    centred = d.copy()
    for exposure in sorted(interacted_exposures_from_rhs(rhs)):
        group_mean = centred.groupby(group_col)[exposure].transform("mean")
        centred[f"{exposure}_CQ"] = centred[exposure] - group_mean
    return centred

def build_design(df, rhs, design_info=None):
    """Build the design matrix for ``1 + rhs`` on *df*; return ``(matrix, design_info)``.

    With *design_info* supplied the matrix is rebuilt from that existing
    specification (and *rhs* is not used); otherwise a new specification is
    created from *rhs* and returned alongside the matrix.
    """
    if design_info is not None:
        (matrix,) = build_design_matrices([design_info], df, return_type="dataframe")
        return matrix, design_info

    matrix = dmatrix("1 + " + rhs, data=df, return_type="dataframe")
    return matrix, matrix.design_info

def predict_rate(beta, X):
    """Return ``exp(X @ beta)`` as a float array (the model is fitted on the log rate)."""
    linear_predictor = np.asarray(X.values @ beta, float)
    rate = np.exp(linear_predictor)  # rate per 100k
    return rate

# =========================
# CI via Monte Carlo (PSD clip)
# =========================
def draw_betas_psd(beta_hat, cov, n=1000, seed=123, eig_floor=1e-8):
    """Draw *n* coefficient vectors from ``N(beta_hat, cov)`` after repairing *cov*.

    *cov* is symmetrised and its eigenvalues are raised to at least
    *eig_floor* before sampling, so the sampler always receives a symmetric
    positive-definite matrix. Returns an array of shape ``(n, len(beta_hat))``.
    """
    generator = np.random.default_rng(seed)

    sym = np.asarray(cov, float)
    sym = 0.5 * (sym + sym.T)

    eigvals, eigvecs = np.linalg.eigh(sym)
    eigvals = np.clip(eigvals, eig_floor, None)
    # rebuild V diag(w) V^T from the floored spectrum
    repaired = (eigvecs * eigvals) @ eigvecs.T

    centre = np.asarray(beta_hat, float)
    return generator.multivariate_normal(mean=centre, cov=repaired, size=n)

def agg_pct_by_codes(ro, rc, codes, n_codes):
    """Per-group attributable share in percent: ``100 * sum(ro - rc) / sum(ro)``.

    *codes* assigns each row to a group in ``0..n_codes-1``. The denominator
    is floored at 1e-12, so a group with no rows yields 0.
    """
    observed_total = np.bincount(codes, weights=ro, minlength=n_codes).astype(float)
    excess_total = np.bincount(codes, weights=(ro - rc), minlength=n_codes).astype(float)
    safe_total = np.maximum(observed_total, 1e-12)
    return excess_total / safe_total * 100.0

def impact_map_with_ci(beta_hat, cov, Xo, Xc, iso3_series, n_draw=1000, seed=123):
    """Per-country attributable impact (%) with a simulation interval.

    The point estimate uses *beta_hat*; the interval is the ``CI_LO`` /
    ``CI_HI`` percentiles of the same quantity recomputed for *n_draw*
    coefficient vectors from ``draw_betas_psd``. *Xo* and *Xc* are the
    observed and counterfactual design matrices, row-aligned with
    *iso3_series*.

    Returns:
        DataFrame with columns ``iso3, Impact, CI_lo, CI_hi, sig95`` where
        ``sig95`` is 1 when the interval does not contain 0.
    """
    labels = iso3_series.astype(str).values
    # sorted unique countries plus, for each row, the position of its country
    countries, codes = np.unique(labels, return_inverse=True)
    n_countries = len(countries)

    point = agg_pct_by_codes(
        predict_rate(beta_hat, Xo), predict_rate(beta_hat, Xc), codes, n_countries
    )

    draws = draw_betas_psd(beta_hat, cov, n=n_draw, seed=seed)
    sims = np.empty((n_draw, n_countries), dtype=float)
    obs_mat = Xo.values
    cf_mat = Xc.values
    for i, b in enumerate(draws):
        rate_obs = np.exp(obs_mat @ b)
        rate_cf = np.exp(cf_mat @ b)
        sims[i, :] = agg_pct_by_codes(rate_obs, rate_cf, codes, n_countries)

    lower = np.nanpercentile(sims, CI_LO, axis=0)
    upper = np.nanpercentile(sims, CI_HI, axis=0)

    covers_zero = (lower <= 0.0) & (upper >= 0.0)

    return pd.DataFrame({
        "iso3": countries,
        "Impact": point,
        "CI_lo": lower,
        "CI_hi": upper,
        "sig95": (~covers_zero).astype(int),
    })

def impact_ts_by_group_with_ci(beta_hat, cov, Xo, Xc, years, group, group_labels, n_draw=1000, seed=123):
    """Attributable impact (%) per year and group, with a simulation interval.

    Rows are pooled into year x group cells; each cell's impact comes from
    ``agg_pct_by_codes``. The interval is the ``CI_LO`` / ``CI_HI``
    percentiles over *n_draw* coefficient vectors from ``draw_betas_psd``.
    Rows whose group is not listed in *group_labels* are ignored.

    Returns:
        DataFrame with columns ``year, SDI_Q, Impact, CI_lo, CI_hi, sig95``
        sorted by ``SDI_Q`` then ``year``; one row per (year, label) pair.
    """
    year_arr = years.astype(int).values
    grp_arr = group.astype(str).values

    y_uniq = np.unique(year_arr)
    g_uniq = np.array(group_labels, dtype=str)
    n_years, n_groups = len(y_uniq), len(g_uniq)

    year_pos = {yy: i for i, yy in enumerate(y_uniq)}
    grp_pos = {gg: i for i, gg in enumerate(g_uniq)}

    # safety filter: drop rows whose group is not among the labels
    keep = np.array([gg in grp_pos for gg in grp_arr], dtype=bool)
    obs_mat = Xo.values[keep, :]
    cf_mat = Xc.values[keep, :]

    y_code = np.array([year_pos[yy] for yy in year_arr[keep]], dtype=int)
    g_code = np.array([grp_pos[gg] for gg in grp_arr[keep]], dtype=int)

    # flatten (year, group) into one cell index so a single bincount suffices
    cell = y_code * n_groups + g_code
    n_cells = n_years * n_groups

    def _cell_impacts(b):
        rate_obs = np.exp(obs_mat @ b)
        rate_cf = np.exp(cf_mat @ b)
        return agg_pct_by_codes(rate_obs, rate_cf, cell, n_cells).reshape(n_years, n_groups)

    point = _cell_impacts(beta_hat)

    draws = draw_betas_psd(beta_hat, cov, n=n_draw, seed=seed)
    sims = np.empty((n_draw, n_years, n_groups), dtype=float)
    for i, b in enumerate(draws):
        sims[i, :, :] = _cell_impacts(b)

    lower = np.nanpercentile(sims, CI_LO, axis=0)
    upper = np.nanpercentile(sims, CI_HI, axis=0)

    rows = [
        (int(yy), str(gg), float(point[yi, gi]), float(lower[yi, gi]), float(upper[yi, gi]))
        for yi, yy in enumerate(y_uniq)
        for gi, gg in enumerate(g_uniq)
    ]
    table = pd.DataFrame(rows, columns=["year","SDI_Q","Impact","CI_lo","CI_hi"])
    covers_zero = (table["CI_lo"] <= 0.0) & (table["CI_hi"] >= 0.0)
    table["sig95"] = (~covers_zero).astype(int)
    return table.sort_values(["SDI_Q","year"])

# =========================
# Plotting
# =========================
def plot_map_impact_with_hatch(ax, world, df_imp, title):
    """Draw a choropleth of ``Impact`` and hatch the non-significant countries.

    Args:
        ax: cartopy GeoAxes to draw on.
        world: GeoDataFrame of country polygons with an ``iso3`` column.
        df_imp: per-country table with ``iso3`` and ``Impact``; when it also
            has ``sig95``, countries with a finite impact and ``sig95 == 0``
            are overlaid with a ``//`` hatch.
        title: axes title, left-aligned.

    Returns:
        The ``Normalize`` used for the colours, so the caller can build a
        matching colorbar.

    Raises:
        ValueError: if no country has a finite ``Impact``.
    """
    g = world.merge(df_imp, on="iso3", how="left")

    vals = pd.to_numeric(g["Impact"], errors="coerce").values
    ok = np.isfinite(vals)
    if ok.sum() == 0:
        raise ValueError("No finite Impact values.")

    # colour limits from the 2nd-98th percentiles, lower bound never below 0;
    # the epsilon keeps vmax strictly above vmin when all values coincide
    v2 = float(np.nanpercentile(vals[ok], 2))
    v98 = float(np.nanpercentile(vals[ok], 98))
    vmin = max(0.0, v2)
    vmax = max(v98, vmin + 1e-6)
    norm = Normalize(vmin=vmin, vmax=vmax)

    ax.set_title(title, loc="left", fontsize=12)
    ax.set_global()
    ax.set_facecolor("white")

    g.plot(
        column="Impact",
        ax=ax,
        transform=ccrs.PlateCarree(),
        cmap="Reds",
        norm=norm,
        edgecolor="white",
        linewidth=0.35,
        missing_kwds=dict(color="#f2f2f2", edgecolor="white", linewidth=0.25, hatch=".."),
        zorder=2,
    )

    # Previously sig95 was never drawn although the function name and the
    # callers' titles announce a hatch for non-significant countries.
    if "sig95" in g.columns:
        sig = pd.to_numeric(g["sig95"], errors="coerce").to_numpy()
        not_sig = g[ok & (sig == 0)]
        if not not_sig.empty:
            not_sig.plot(
                ax=ax,
                transform=ccrs.PlateCarree(),
                facecolor="none",      # keep the choropleth colour visible
                edgecolor="#4c4c4c",   # hatch lines take the edge colour
                linewidth=0.0,
                hatch="//",
                zorder=3,
            )

    ax.set_extent([-180, 180, -58, 85], crs=ccrs.PlateCarree())
    return norm

def plot_impact_lines_ci(ax, df_ts, title, sdi_labels, sdi_hex, ylabel):
    """Plot one impact line with a shaded interval band per SDI group.

    Args:
        ax: matplotlib axes to draw on.
        df_ts: table with ``year, SDI_Q, Impact, CI_lo, CI_hi``; ``None`` or an
            empty table produces a "No data" placeholder.
        title: axes title, left-aligned.
        sdi_labels: group labels to draw, in legend order.
        sdi_hex: colours indexed like *sdi_labels*; falsy means default colours.
        ylabel: y-axis label.
    """
    ax.set_title(title, loc="left", fontsize=12)

    has_rows = df_ts is not None and len(df_ts) > 0
    if not has_rows:
        ax.text(0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes)
        return

    table = df_ts.copy()
    table["SDI_Q"] = table["SDI_Q"].astype(str)

    for idx, label in enumerate(sdi_labels):
        series = table[table["SDI_Q"] == label].sort_values("year")
        if series.empty:
            continue
        colour = sdi_hex[idx] if sdi_hex else None

        ax.plot(series["year"], series["Impact"], lw=2.6, color=colour, label=label)
        ax.fill_between(
            series["year"].values,
            series["CI_lo"].values,
            series["CI_hi"].values,
            alpha=0.18,
            color=colour,
        )

        # annotate the final point of the line with its value
        last = series.iloc[-1]
        ax.text(
            int(last["year"]) + 0.2,
            float(last["Impact"]),
            f"{float(last['Impact']):.2f}%",
            fontsize=9,
            va="center",
            color=colour,
        )

    ax.set_xlabel("Year")
    ax.set_ylabel(ylabel)
    ax.legend(frameon=False, ncol=len(sdi_labels), fontsize=9)

def impact_unit():
    """Return the axis/colorbar label used for impact values."""
    label = "Attributable impact (%)"
    return label

# =========================
# Data prep
# =========================
def prep_means_dataset(df_mean, sdi_map_model, labels_model, sdi_map_plot5, labels_plot5, cfg, resid_params):
    """Turn the country-mean table into the frame the model is fitted on.

    Adds the model SDI group (``SDI_Q``), the plotting quintile (``SDI_Q5``),
    the transformed exposures (``HAP_X``, ``PM25_X``), the climate regressors
    and ``DENS_X``; optionally adds group-centred ``*_CQ`` columns. Countries
    without a model SDI group, or with a missing value in any modelling
    column, are dropped.
    """
    frame = df_mean.copy()

    # SDI group used by the model (Q4/Q5 scheme); unmapped countries are dropped
    frame["SDI_Q"] = frame["iso3"].map(sdi_map_model)
    frame = frame.dropna(subset=["SDI_Q"]).copy()
    frame["SDI_Q"] = pd.Categorical(frame["SDI_Q"], categories=labels_model, ordered=True)

    # SDI quintile used only for plotting
    quintile = frame["iso3"].map(sdi_map_plot5)
    frame["SDI_Q5"] = quintile
    frame["SDI_Q5"] = pd.Categorical(frame["SDI_Q5"], categories=labels_plot5, ordered=True)

    frame["HAP_X"]  = transform_hap_from_raw(frame["hap_mean"], cfg["hap"])
    frame["PM25_X"] = transform_pm25_from_raw(frame["pm25_mean"], cfg["pm25"])
    frame = apply_climate_struct(frame, resid_params, is_means=True)
    frame["DENS_X"] = safe_log1p(pd.to_numeric(frame["dens_mean"], errors="coerce"))

    centred = cfg["center_inter"] == "centerWithinQ"
    if centred:
        frame = add_center_within_q(frame, cfg["rhs"], group_col="SDI_Q")

    need = ["log_lui_rate_mean","SDI_Q","HAP_X","PM25_X","CLIM1_X","CLIM2_X","DENS_X"]
    if centred:
        # centred columns are required only when the formula mentions them
        need += [
            f"{x}_CQ"
            for x in ("HAP_X","PM25_X","CLIM1_X","CLIM2_X")
            if f"{x}_CQ" in cfg["rhs"]
        ]

    return frame.dropna(subset=need).copy()

def prep_panel_dataset(df_panel, sdi_map_model, labels_model, sdi_map_plot5, labels_plot5, cfg, resid_params):
    """Build the country-year frame on which impacts are predicted.

    Mirrors ``prep_means_dataset`` but works on yearly rows: attaches the SDI
    groups, coerces the raw exposure/control columns to numeric, derives the
    model regressors, drops incomplete rows and keeps years in ``[Y0, Y1]``.
    """
    frame = df_panel.copy()

    # model SDI group; rows of unmapped countries are dropped
    frame["SDI_Q"] = frame["iso3"].map(sdi_map_model)
    frame = frame.dropna(subset=["SDI_Q"]).copy()
    frame["SDI_Q"] = pd.Categorical(frame["SDI_Q"], categories=labels_model, ordered=True)

    # plotting quintile
    frame["SDI_Q5"] = frame["iso3"].map(sdi_map_plot5)
    frame["SDI_Q5"] = pd.Categorical(frame["SDI_Q5"], categories=labels_plot5, ordered=True)

    # working copies of the raw inputs, coerced to numeric
    raw_sources = (
        ("HAP_raw", HAP_RAW),
        ("PM25_raw", PM25_RAW),
        ("TAVG", TAVG_RAW),
        ("AH", AH_RAW),
        ("DENS", DENS_RAW),
    )
    for target, source in raw_sources:
        frame[target] = pd.to_numeric(frame[source], errors="coerce")

    frame["HAP_X"]  = transform_hap_from_raw(frame["HAP_raw"], cfg["hap"])
    frame["PM25_X"] = transform_pm25_from_raw(frame["PM25_raw"], cfg["pm25"])

    frame = apply_climate_struct(frame, resid_params, is_means=False)
    frame["DENS_X"] = safe_log1p(frame["DENS"])

    centred = cfg["center_inter"] == "centerWithinQ"
    if centred:
        frame = add_center_within_q(frame, cfg["rhs"], group_col="SDI_Q")

    need = ["iso3","year","SDI_Q","SDI_Q5","HAP_raw","PM25_raw","HAP_X","PM25_X",
            "CLIM1_X","CLIM2_X","DENS_X", cfg["pop"], cfg["y"]]
    if centred:
        # centred columns are required only when the formula mentions them
        need += [
            f"{x}_CQ"
            for x in ("HAP_X","PM25_X","CLIM1_X","CLIM2_X")
            if f"{x}_CQ" in cfg["rhs"]
        ]

    frame = frame.dropna(subset=need).copy()
    frame["year"] = pd.to_numeric(frame["year"], errors="coerce").astype(int)
    in_window = (frame["year"] >= Y0) & (frame["year"] <= Y1)
    return frame[in_window].copy()

# =========================
# runner
# =========================
def run_one_age(df0_age, world, age_key, cfg, tmrel_hap, tmrel_pm25):
    """Fit the configured model for one age group and produce its impact outputs.

    Steps, in order:
      1. Build country means, SDI lookups and climate residual parameters.
      2. Fit OLS (HC1 covariance) of ``log_lui_rate_mean`` on ``cfg["rhs"]``
         using the country means; save the summary text and coefficient CSV.
      3. Build observed and counterfactual design matrices on the yearly
         panel, where the counterfactual caps raw HAP (or PM2.5) at its TMREL.
      4. Compute per-country impacts for ``MAP_YEAR`` and per-year impacts by
         SDI quintile, both with simulation intervals; save them as CSV.
      5. Draw and save the 2x2 figure (two maps, two time-series panels).

    Args:
        df0_age: panel rows for this age group.
        world: country polygons passed to the map plotting helper.
        age_key: age label used in file names, titles and messages.
        cfg: the ``AGE_CFG`` entry for this age group.
        tmrel_hap, tmrel_pm25: counterfactual caps for raw HAP and PM2.5.

    Returns:
        None. Returns early, after the coefficient files are written, when the
        panel has no rows for ``MAP_YEAR``.
    """
    out_dir = os.path.join(OUT_ROOT, cfg["title_tag"])
    os.makedirs(out_dir, exist_ok=True)

    # means + SDI maps
    df_mean = build_iso3_means(df0_age, cfg["y"], cfg["pop"])
    sdi_map_model, labels_model, sdi_map_plot5, labels_plot5 = make_sdi_maps_from_means(
        df_mean, scheme_model=cfg["sdi_scheme"]
    )
    resid_params = fit_resid_params_on_means(df_mean, cfg["clim_struct"])

    # fit on means
    dfit = prep_means_dataset(df_mean, sdi_map_model, labels_model, sdi_map_plot5, labels_plot5, cfg, resid_params)

    formula = "log_lui_rate_mean ~ " + cfg["rhs"]
    res = smf.ols(formula, data=dfit).fit(cov_type="HC1")

    # coefficient vector and covariance, used positionally against the design
    # matrices built below
    beta_hat = res.params.values
    cov_hat  = res.cov_params().values

    # save summary
    with open(os.path.join(out_dir, "BEST_model_summary_used.txt"), "w", encoding="utf-8") as f:
        f.write("AGE = " + age_key + "\n")
        f.write(f"TMREL_HAP={tmrel_hap} | TMREL_PM25={tmrel_pm25}\n")
        f.write("FORMULA:\n" + formula + "\n\n")
        f.write(res.summary().as_text())

    coef = pd.DataFrame({
        "term": res.params.index,
        "beta": res.params.values,
        "se_HC1": res.bse.values,
        "p_HC1": res.pvalues.values
    }).sort_values("p_HC1")
    coef.to_csv(os.path.join(out_dir, "BEST_model_coef_used.csv"), index=False, encoding="utf-8-sig")

    # panel prep
    dpanel = prep_panel_dataset(df0_age, sdi_map_model, labels_model, sdi_map_plot5, labels_plot5, cfg, resid_params)
    obs = dpanel.copy()

    # ---- build obs design ----
    # `di` is the design specification created here and reused for every
    # counterfactual / subset matrix below
    Xo, di = build_design(obs, cfg["rhs"], None)

    # NOTE(review): when center_inter == "centerWithinQ", add_center_within_q
    # recomputes group means on whichever frame it receives (counterfactual
    # panel, 2019 subset), so the centring can differ between frames -
    # confirm this is intended. No AGE_CFG entry currently uses that mode.

    # ---- counterfactual HAP ----
    # raw HAP capped at the TMREL, then re-transformed
    cf_h = obs.copy()
    cf_h_raw = np.minimum(cf_h["HAP_raw"].values.astype(float), float(tmrel_hap))
    cf_h["HAP_X"] = transform_hap_from_raw(cf_h_raw, cfg["hap"])
    if cfg["center_inter"] == "centerWithinQ":
        cf_h = add_center_within_q(cf_h, cfg["rhs"], group_col="SDI_Q")
    Xh, _ = build_design(cf_h, cfg["rhs"], di)

    # ---- counterfactual PM25 ----
    # raw PM2.5 capped at the TMREL, then re-transformed
    cf_p = obs.copy()
    cf_p_raw = np.minimum(cf_p["PM25_raw"].values.astype(float), float(tmrel_pm25))
    cf_p["PM25_X"] = transform_pm25_from_raw(cf_p_raw, cfg["pm25"])
    if cfg["center_inter"] == "centerWithinQ":
        cf_p = add_center_within_q(cf_p, cfg["rhs"], group_col="SDI_Q")
    Xp, _ = build_design(cf_p, cfg["rhs"], di)

    # =========================
    # MAP 2019: Impact + CI + sig
    # =========================
    d2019 = obs[obs["year"] == MAP_YEAR].copy()
    if d2019.empty:
        # early exit: the time-series CSVs and the figure are skipped as well
        print(f"[SKIP MAP] {age_key} no {MAP_YEAR} rows.")
        return

    d2019_o = d2019.copy()
    d2019_h = d2019.copy()
    d2019_p = d2019.copy()

    d2019_h_raw = np.minimum(d2019_h["HAP_raw"].values.astype(float), float(tmrel_hap))
    d2019_h["HAP_X"] = transform_hap_from_raw(d2019_h_raw, cfg["hap"])

    d2019_p_raw = np.minimum(d2019_p["PM25_raw"].values.astype(float), float(tmrel_pm25))
    d2019_p["PM25_X"] = transform_pm25_from_raw(d2019_p_raw, cfg["pm25"])

    if cfg["center_inter"] == "centerWithinQ":
        d2019_o = add_center_within_q(d2019_o, cfg["rhs"], group_col="SDI_Q")
        d2019_h = add_center_within_q(d2019_h, cfg["rhs"], group_col="SDI_Q")
        d2019_p = add_center_within_q(d2019_p, cfg["rhs"], group_col="SDI_Q")

    X2019_o, _ = build_design(d2019_o, cfg["rhs"], di)
    X2019_h, _ = build_design(d2019_h, cfg["rhs"], di)
    X2019_p, _ = build_design(d2019_p, cfg["rhs"], di)

    # HAP and PM2.5 use different seeds (SIM_SEED, SIM_SEED + 1)
    imp_hap_map_ci = impact_map_with_ci(beta_hat, cov_hat, X2019_o, X2019_h, d2019_o["iso3"],
                                        n_draw=SIM_N, seed=SIM_SEED)
    imp_pm_map_ci  = impact_map_with_ci(beta_hat, cov_hat, X2019_o, X2019_p, d2019_o["iso3"],
                                        n_draw=SIM_N, seed=SIM_SEED + 1)

    imp_hap_map_ci.to_csv(os.path.join(out_dir, f"IMPACTMAP_HAP_{MAP_YEAR}_pct_{age_key}_CI.csv"),
                          index=False, encoding="utf-8-sig")
    imp_pm_map_ci.to_csv(os.path.join(out_dir, f"IMPACTMAP_PM25_{MAP_YEAR}_pct_{age_key}_CI.csv"),
                         index=False, encoding="utf-8-sig")

    # =========================
    # TS: shown by SDI quintile (Q5), independent of the model's SDI scheme
    # =========================
    ts_hap_ci = impact_ts_by_group_with_ci(beta_hat, cov_hat, Xo, Xh,
                                          years=obs["year"], group=obs["SDI_Q5"],
                                          group_labels=SDI_LABELS_Q5,
                                          n_draw=SIM_N, seed=SIM_SEED)
    ts_pm_ci  = impact_ts_by_group_with_ci(beta_hat, cov_hat, Xo, Xp,
                                          years=obs["year"], group=obs["SDI_Q5"],
                                          group_labels=SDI_LABELS_Q5,
                                          n_draw=SIM_N, seed=SIM_SEED + 1)

    ts_hap_ci.to_csv(os.path.join(out_dir, f"IMPACTTS_HAP_pct_byQ5_{age_key}_CI.csv"),
                     index=False, encoding="utf-8-sig")
    ts_pm_ci.to_csv(os.path.join(out_dir, f"IMPACTTS_PM25_pct_byQ5_{age_key}_CI.csv"),
                    index=False, encoding="utf-8-sig")

    # =========================
    # Plot 2x2
    # =========================
    fig = plt.figure(figsize=(13.8, 7.6))
    gs = fig.add_gridspec(2, 2, wspace=0.06, hspace=0.22)

    # top row: maps (Robinson projection); bottom row: time series
    ax1 = fig.add_subplot(gs[0,0], projection=ccrs.Robinson())
    ax2 = fig.add_subplot(gs[0,1], projection=ccrs.Robinson())
    ax3 = fig.add_subplot(gs[1,0])
    ax4 = fig.add_subplot(gs[1,1])

    unit_map = impact_unit()
    unit_ts  = impact_unit()

    n1 = plot_map_impact_with_hatch(
        ax1, world, imp_hap_map_ci,
        f"a | {age_key} impact (%) due to HAP ({MAP_YEAR})  (//: n.s.)"
    )
    n2 = plot_map_impact_with_hatch(
        ax2, world, imp_pm_map_ci,
        f"b | {age_key} impact (%) due to PM2.5 ({MAP_YEAR})  (//: n.s.)"
    )

    # colorbars
    # placed on manually positioned axes (figure-fraction coordinates), one per map
    smap1 = mpl.cm.ScalarMappable(norm=n1, cmap="Reds"); smap1.set_array([])
    cax1 = fig.add_axes([0.47, 0.56, 0.015, 0.32])
    cb1 = plt.colorbar(smap1, cax=cax1); cb1.set_label(unit_map)

    smap2 = mpl.cm.ScalarMappable(norm=n2, cmap="Reds"); smap2.set_array([])
    cax2 = fig.add_axes([0.92, 0.56, 0.015, 0.32])
    cb2 = plt.colorbar(smap2, cax=cax2); cb2.set_label(unit_map)

    plot_impact_lines_ci(
        ax3, ts_hap_ci,
        f"c | {age_key} impact(t) (%) due to HAP by SDI (plot=Q5; model={cfg['sdi_scheme']})",
        SDI_LABELS_Q5, SDI_HEX_Q5, ylabel=unit_ts
    )
    plot_impact_lines_ci(
        ax4, ts_pm_ci,
        f"d | {age_key} impact(t) (%) due to PM2.5 by SDI (plot=Q5; model={cfg['sdi_scheme']})",
        SDI_LABELS_Q5, SDI_HEX_Q5, ylabel=unit_ts
    )

    out_png = os.path.join(out_dir, f"FIG_air_2x2_IMPACTpct_{age_key}_Q5plot_CI_HATCH.png")
    out_pdf = os.path.join(out_dir, f"FIG_air_2x2_IMPACTpct_{age_key}_Q5plot_CI_HATCH.pdf")
    plt.savefig(out_png, dpi=300, bbox_inches="tight")
    plt.savefig(out_pdf, bbox_inches="tight")
    plt.show()
    plt.close(fig)

    # ---- debug: share of countries whose MAP_YEAR PM25_raw is <= TMREL
    # (for those, counterfactual ~= observed, so impact ~= 0) ----
    d2019_debug = dpanel[dpanel["year"] == MAP_YEAR].copy()
    if not d2019_debug.empty:
        iso_m = d2019_debug.groupby("iso3")["PM25_raw"].mean()
        share_small = float((iso_m <= tmrel_pm25).mean())
        print(f"[DEBUG] {age_key} 2019 share( mean PM25_raw <= TMREL_PM25 ) = {share_small:.3f}")

    print(f"[OK] age={age_key} -> {out_dir}")

# =========================
# MAIN
# =========================
def main():
    """Load the panel, validate and clean it, then run every age group in ``AGE_CFG``.

    The TMREL values are optionally divided by 10 (``AUTO_SCALE_DIV10``).
    All required columns except ``iso3`` are coerced to numeric, rows outside
    ``[Y0, Y1]`` are removed, and ``run_one_age`` is called for each age group
    on the rows that are complete for that group's columns.
    """
    # TMREL must be on the same scale as the exposure variables
    tmrel_hap = TMREL_HAP_BASE / 10.0 if AUTO_SCALE_DIV10 else TMREL_HAP_BASE
    tmrel_pm25 = TMREL_PM25_BASE / 10.0 if AUTO_SCALE_DIV10 else TMREL_PM25_BASE

    print(f"[TMREL] AUTO_SCALE_DIV10={AUTO_SCALE_DIV10} | TMREL_HAP={tmrel_hap} | TMREL_PM25={tmrel_pm25}")

    panel = pd.read_csv(IN_FP)
    panel["iso3"] = panel["iso3"].astype(str).str.upper().str.strip()
    panel["year"] = pd.to_numeric(panel["year"], errors="coerce")

    # shared columns plus every age group's outcome/population columns
    shared_cols = ["iso3","year",HAP_RAW,PM25_RAW,TAVG_RAW,AH_RAW,DENS_RAW,SDI_RAW]
    age_cols = [cfg[key] for cfg in AGE_CFG.values() for key in ("y", "pop")]
    need = sorted(set(shared_cols + age_cols))
    require_cols(panel, need)

    for col in (c for c in need if c != "iso3"):
        panel[col] = pd.to_numeric(panel[col], errors="coerce")

    panel = panel.dropna(subset=["iso3","year"]).copy()
    in_window = (panel["year"]>=Y0) & (panel["year"]<=Y1)
    panel = panel[in_window].copy()
    panel["year"] = panel["year"].astype(int)

    # quick look at the exposure scales (to judge whether /10 is sensible)
    pm25 = panel[PM25_RAW]
    hap = panel[HAP_RAW]
    print("[CHECK] PM25_raw median/p95:",
          float(pm25.median()), float(pm25.quantile(0.95)))
    print("[CHECK] HAP_raw  median/p95:",
          float(hap.median()), float(hap.quantile(0.95)))

    world = load_world(SHP_FP)

    rule = "="*96
    for age_key, cfg in AGE_CFG.items():
        cols_age = [cfg["y"], cfg["pop"], HAP_RAW, PM25_RAW, TAVG_RAW, AH_RAW, DENS_RAW, SDI_RAW]
        complete = panel.dropna(subset=cols_age).copy()
        if complete.empty:
            print(f"[SKIP] age={age_key} empty after dropna.")
            continue

        print("\n" + rule)
        print(f"RUN age={age_key} | model SDI scheme={cfg['sdi_scheme']} | plot=Q5 | TMREL={tmrel_pm25}")
        print(rule)

        run_one_age(complete, world, age_key, cfg, tmrel_hap=tmrel_hap, tmrel_pm25=tmrel_pm25)

    print("\nDONE. OUT_ROOT =", OUT_ROOT)

if __name__ == "__main__":
    main()

## Robustness check using two-way fixed effects

In [None]:
# -*- coding: utf-8 -*-
"""

"""

import os
import warnings
import numpy as np
import pandas as pd

import matplotlib as mpl
import matplotlib.pyplot as plt

import statsmodels.api as sm
import statsmodels.formula.api as smf

import geopandas as gpd
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from matplotlib.colors import TwoSlopeNorm

warnings.filterwarnings("ignore")

# =========================
# STYLE
# =========================
mpl.rcParams["pdf.fonttype"] = 42
mpl.rcParams["ps.fonttype"]  = 42
plt.rcParams["font.sans-serif"] = ["Arial", "SimHei", "DejaVu Sans"]
plt.rcParams["axes.unicode_minus"] = False

# =========================
# PATHS
# =========================
IN_FP  = r"D:\AAUDE\paper_v2\paper2\data\model_outputs\panel0_1990_2019_direct_meteo_GBDPM_HAP_lui.csv"
SHP_FP = r"D:\AAUDE\paper_v2\paper2\data\ne_10m_admin_0_countries\ne_10m_admin_0_countries.shp"
OUT_DIR = r"D:\AAUDE\paper_v2\paper2\data\model_outputs\AF_PAF_space_time_v1"
os.makedirs(OUT_DIR, exist_ok=True)

# =========================
# CONFIG: choose age group
# =========================
# NOTE: the column is named uri_total but actually holds LUI_total (the name was never updated)
Y_COL   = "uri_total"
POP_COL = "pop_total"

# exposures (raw)
HAP_COL  = "hap_pm_pw"
PM25_COL = "pm25_pw"

# controls (raw)
TAVG_COL = "tavg_pw_C"
AH_COL   = "ah_pw"
DENS_COL = "density_total_pkm2"
SDI_COL  = "sdi"

# years
Y0, Y1 = 1990, 2019
MAP_YEAR = 2019

# TMREL (counterfactual)
TMREL_HAP  = 5.0
TMREL_PM25 = 5.0

# model family
USE_NB = False        # True=NB(alpha fixed), False=Poisson
ALPHA_NB = 1.0        # only used if USE_NB=True

# SDI grouping (stable by country mean SDI)
SDI_Q = 5
SDI_HEX = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd"]  # Q1..Q5
# =========================
# helpers
# =========================
def require_cols(df, cols):
    """Raise ``ValueError`` listing every name in *cols* that is not a column of *df*."""
    absent = []
    for name in cols:
        if name not in df.columns:
            absent.append(name)
    if absent:
        raise ValueError(f"Missing columns: {absent}")

def load_world_shp(shp_fp):
    """Read the country shapefile and return it with a cleaned ``iso3`` column.

    The first column found among the candidate names is renamed to ``iso3``
    and upper-cased/stripped. Rows coded ``-99`` and the row for ``ATA`` are
    removed.

    Raises:
        ValueError: if none of the candidate code columns is present.
    """
    world = gpd.read_file(shp_fp)

    cand = ["ADM0_A3", "ISO_A3", "SOV_A3", "WB_A3", "ISO3", "iso3"]
    code_col = next((name for name in cand if name in world.columns), None)
    if code_col is None:
        raise ValueError(f"Shapefile lacks iso3 field. Tried: {cand}. Available: {list(world.columns)}")

    world = world.rename(columns={code_col: "iso3"})
    world["iso3"] = world["iso3"].astype(str).str.upper().str.strip()

    # "-99" is treated as "no code": blank it out, then drop those rows
    world.loc[world["iso3"] == "-99", "iso3"] = np.nan
    has_code = world["iso3"].notna()
    world = world[has_code].copy()

    # drop Antarctica
    return world[world["iso3"] != "ATA"].copy()
def plot_paf_map(world_gdf, paf_df, title, out_fp):
    """Draw a Robinson-projection choropleth of country PAF values and save it.

    Parameters
    ----------
    world_gdf : country polygons with an ``iso3`` column.
    paf_df : table with ``iso3`` and ``PAF`` columns. It is left-joined onto
        the polygons, so countries without a match are drawn with the hatched
        "missing" style.
    title : axes title.
    out_fp : output path handed to ``plt.savefig`` (300 dpi).

    Raises
    ------
    ValueError
        If the merged table contains no finite ``PAF`` value.
    """
    g = world_gdf.merge(paf_df, on="iso3", how="left").copy()

    vals = pd.to_numeric(g["PAF"], errors="coerce").to_numpy()
    finite = np.isfinite(vals)
    if finite.sum() == 0:
        raise ValueError("No finite PAF values to plot.")

    # Robust symmetric colour limits from the 2nd / 98th percentiles.
    v2 = np.nanpercentile(vals[finite], 2)
    v98 = np.nanpercentile(vals[finite], 98)
    lim = float(max(abs(v2), abs(v98)))
    if lim <= 0.0:
        # All plotted values are exactly zero. TwoSlopeNorm needs
        # vmin < vcenter < vmax, so fall back to a unit range.
        lim = 1.0
    norm = TwoSlopeNorm(vmin=-lim, vcenter=0.0, vmax=lim)

    fig = plt.figure(figsize=(12.2, 6.0))
    try:
        ax = plt.axes(projection=ccrs.Robinson())
        ax.set_global()

        # No ocean background: plain white axes face.
        ax.set_facecolor("white")

        # Land base layer only (no ocean feature).
        ax.add_feature(cfeature.LAND, facecolor="#d9d9d9", zorder=0)
        ax.coastlines(linewidth=0.25, alpha=0.6)

        # Country polygons coloured by PAF; geometries are passed with a
        # PlateCarree transform.
        g.plot(
            column="PAF",
            ax=ax,
            transform=ccrs.PlateCarree(),
            cmap="RdBu_r",
            norm=norm,
            edgecolor="#4c4c4c",
            linewidth=0.25,
            missing_kwds=dict(
                color="#f2f2f2", edgecolor="#4c4c4c", linewidth=0.2, hatch="..", label="Missing"
            ),
            zorder=2,
        )

        ax.set_title(title, fontsize=13)

        # Colourbar built from a standalone mappable sharing the map's norm.
        smap = mpl.cm.ScalarMappable(norm=norm, cmap="RdBu_r")
        smap.set_array([])
        cbar = plt.colorbar(smap, ax=ax, fraction=0.03, pad=0.03)
        cbar.set_label("PAF (fraction)", fontsize=11)

        plt.tight_layout()
        plt.savefig(out_fp, dpi=300, bbox_inches="tight")
        plt.show()
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)
    print("Saved:", out_fp)
def af_timeseries_by_sdiq_ci(res, df, x_col, x0, n_boot=300, seed=123):
    """Attributable fraction by year x SDI quintile with bootstrap percentile CIs.

    Countries (``iso3``) are resampled with replacement; each draw's panel is
    the concatenation of the sampled countries' rows, and AF is recomputed on
    it with ``af_timeseries_by_sdiq``. The same fitted ``res`` is used for
    every draw, so the interval reflects the country resampling only, not
    re-estimation of the model.

    Parameters
    ----------
    res : fitted model forwarded to ``af_timeseries_by_sdiq``.
    df : panel with at least ``iso3``, ``year`` and ``SDI_Q`` columns.
    x_col : exposure column set to the counterfactual value.
    x0 : counterfactual exposure value.
    n_boot : number of bootstrap draws; with ``n_boot <= 0`` the point
        estimate is returned with NaN bounds.
    seed : seed for ``numpy.random.default_rng``.

    Returns
    -------
    DataFrame with columns ``[year, SDI_Q, AF, ci_lo, ci_hi]`` where the
    bounds are the 2.5% / 97.5% quantiles of the bootstrap AF values.
    """
    rng = np.random.default_rng(seed)
    iso_list = np.array(sorted(df["iso3"].unique()))
    keep = ["year", "SDI_Q", "AF"]

    # Point estimate on the full panel.
    df_point = af_timeseries_by_sdiq(res, df, x_col=x_col, x0=x0)[keep].copy()

    # Split the panel into country blocks once, instead of re-filtering the
    # whole frame for every sampled country in every draw.
    blocks = {iso: block for iso, block in df.groupby("iso3", sort=False)}

    boot_rec = []
    for b in range(n_boot):
        boot_iso = rng.choice(iso_list, size=len(iso_list), replace=True)
        # A country drawn k times contributes its rows k times.
        dboot = pd.concat([blocks[iso] for iso in boot_iso.tolist()], ignore_index=True)

        dab = af_timeseries_by_sdiq(res, dboot, x_col=x_col, x0=x0)[keep].copy()
        dab["b"] = b
        boot_rec.append(dab)

    if not boot_rec:
        # No draws requested: pd.concat([]) would raise, so return NaN bounds.
        return df_point.assign(ci_lo=np.nan, ci_hi=np.nan)

    dfb = pd.concat(boot_rec, ignore_index=True)

    # Percentile interval per (year, SDI_Q) cell.
    qlo, qhi = 0.025, 0.975
    ci = (
        dfb.groupby(["year", "SDI_Q"])["AF"]
           .quantile([qlo, qhi])
           .unstack(level=-1)
           .reset_index()
           .rename(columns={qlo: "ci_lo", qhi: "ci_hi"})
    )

    return df_point.merge(ci, on=["year", "SDI_Q"], how="left")

def prepare_panel(df_raw):
    """Clean the raw panel and add the derived columns the model needs.

    Steps: validate required columns, normalise ``iso3``, coerce the numeric
    columns, drop incomplete rows, keep years in ``[Y0, Y1]``, then add
    ``offset_log`` (log of population clipped at 1), ``log1p_density``,
    ``sdi_bar`` (country mean SDI) and ``SDI_Q`` (1-based quantile group of
    ``sdi_bar`` across countries, so a country keeps one group for all years).

    Parameters
    ----------
    df_raw : raw panel; it is copied, not modified.

    Returns
    -------
    The cleaned DataFrame.

    Raises
    ------
    ValueError
        Via ``require_cols`` when a required column is missing.
    """
    df = df_raw.copy()

    num_cols = [Y_COL, POP_COL, HAP_COL, PM25_COL, TAVG_COL, AH_COL, DENS_COL, SDI_COL]
    need = ["iso3", "year"] + num_cols
    # Validate before touching any column so that a missing "iso3"/"year"
    # produces the descriptive ValueError rather than a bare KeyError.
    require_cols(df, need)

    # Keep missing country codes missing: astype(str) alone would turn them
    # into the pseudo-country "NAN", which dropna() below could not remove.
    iso_raw = df["iso3"]
    df["iso3"] = iso_raw.astype(str).str.upper().str.strip().where(iso_raw.notna())
    df["year"] = pd.to_numeric(df["year"], errors="coerce")

    # numeric
    for c in num_cols:
        df[c] = pd.to_numeric(df[c], errors="coerce")

    df = df.dropna(subset=need).copy()
    df = df[(df["year"] >= Y0) & (df["year"] <= Y1)].copy()
    df["year"] = df["year"].astype(int)

    # offset & transforms (clips guard against log of non-positive values)
    df["offset_log"] = np.log(df[POP_COL].clip(lower=1.0))
    df["log1p_density"] = np.log1p(df[DENS_COL].clip(lower=0.0))

    # Stable SDI group by country mean SDI. With duplicates="drop", tied bin
    # edges are merged, so fewer than SDI_Q groups can result.
    df["sdi_bar"] = df.groupby("iso3")[SDI_COL].transform("mean")
    sdi_country = df.groupby("iso3")["sdi_bar"].first()
    grp = pd.qcut(sdi_country, q=SDI_Q, labels=False, duplicates="drop") + 1
    df["SDI_Q"] = df["iso3"].map(grp).astype(int)

    return df

def fit_country_fe_model(df):
    """Fit the count GLM with country fixed effects and a log-population offset.

    The right-hand side is a linear ``year`` term, ``C(iso3)`` dummies, the two
    exposures, the two meteorological controls and ``log1p_density``. The
    family is negative binomial with fixed ``ALPHA_NB`` when ``USE_NB`` is
    true, otherwise Poisson.

    Returns
    -------
    (res, formula) : the fitted results object and the formula string used.
    """
    rhs_terms = [
        "year",
        "C(iso3)",
        HAP_COL,
        PM25_COL,
        TAVG_COL,
        AH_COL,
        "log1p_density",
    ]
    formula = f"{Y_COL} ~ " + " + ".join(rhs_terms)

    if USE_NB:
        family = sm.families.NegativeBinomial(alpha=ALPHA_NB)
    else:
        family = sm.families.Poisson()

    model = smf.glm(formula=formula, data=df, family=family, offset=df["offset_log"])
    fitted = model.fit(maxiter=200, disp=0)
    return fitted, formula

def predict_mu(res, d):
    """Return the model's predictions for the rows of *d* as a float array.

    ``d["offset_log"]`` is forwarded as the ``offset`` keyword of
    ``res.predict``.
    """
    offset = d["offset_log"]
    fitted = res.predict(d, offset=offset)
    return np.asarray(fitted, dtype=float)

# ---- PAF for one year, one exposure
def paf_country_year(res, d_year, x_col, x0):
    """Population attributable fraction per country for one exposure.

    For each ``iso3`` in *d_year*:
    ``PAF = (sum(mu_obs) - sum(mu_cf)) / sum(mu_obs)``, where ``mu_obs`` is the
    model prediction on the observed rows and ``mu_cf`` the prediction with
    ``x_col`` set to ``x0``. PAF is NaN when ``sum(mu_obs) <= 0``.

    Parameters
    ----------
    res : fitted model accepted by ``predict_mu``.
    d_year : rows to evaluate (the caller passes a single year).
    x_col : exposure column replaced by the counterfactual value.
    x0 : counterfactual exposure value.

    Returns
    -------
    DataFrame with columns ``[iso3, PAF, x0]``, one row per country in order
    of first appearance; an empty frame with those columns for empty input.
    """
    x0 = float(x0)
    if d_year.empty:
        return pd.DataFrame(columns=["iso3", "PAF", "x0"])

    # Predictions are row-wise, so two calls on the whole frame replace the
    # original two calls per country.
    mu_obs = predict_mu(res, d_year)
    d_cf = d_year.copy()
    d_cf[x_col] = x0
    mu_cf = predict_mu(res, d_cf)

    per_row = pd.DataFrame(
        {"iso3": d_year["iso3"].to_numpy(), "mu_obs": mu_obs, "mu_cf": mu_cf}
    )

    out = []
    for iso, grp in per_row.groupby("iso3", sort=False):
        obs = float(grp["mu_obs"].to_numpy().sum())
        cf = float(grp["mu_cf"].to_numpy().sum())
        paf = np.nan if obs <= 0 else float((obs - cf) / obs)
        out.append({"iso3": iso, "PAF": paf, "x0": x0})
    return pd.DataFrame(out)

# ---- AF time series by SDI quintile for one exposure
def af_timeseries_by_sdiq(res, df, x_col, x0):
    """Attributable fraction per year and SDI group for one exposure.

    For every (year, SDI_Q) cell with at least one row:
    ``AF = sum(mu_obs - mu_cf) / max(sum(mu_obs), 1e-12)``, where ``mu_cf`` is
    the prediction with ``x_col`` set to ``x0``.

    Parameters
    ----------
    res : fitted model accepted by ``predict_mu``.
    df : panel with ``year`` and ``SDI_Q`` columns plus the model inputs.
    x_col : exposure column replaced by the counterfactual value.
    x0 : counterfactual exposure value.

    Returns
    -------
    DataFrame with columns ``year, SDI_Q, AF, DeltaCases, sum_mu_obs, x_col,
    x0, n_rows``, ordered by year then group; an empty DataFrame when *df*
    has no rows.
    """
    if df.empty:
        return pd.DataFrame([])

    x0 = float(x0)

    # Predictions are row-wise, so predict once for the whole panel rather
    # than twice per year. This function runs inside every bootstrap draw.
    mu_obs = predict_mu(res, df)
    df_cf = df.copy()
    df_cf[x_col] = x0
    mu_cf = predict_mu(res, df_cf)

    year_arr = df["year"].to_numpy()
    q_arr = df["SDI_Q"].to_numpy()

    rec = []
    for y in sorted(df["year"].unique()):
        in_year = year_arr == y
        for q in range(1, SDI_Q + 1):
            idx = in_year & (q_arr == q)
            n_rows = int(idx.sum())
            if n_rows == 0:
                continue
            sum_obs = float(mu_obs[idx].sum())
            delta = float((mu_obs[idx] - mu_cf[idx]).sum())
            # The floor avoids division by zero for a degenerate cell.
            af = float(delta / max(sum_obs, 1e-12))
            rec.append({
                "year": int(y),
                "SDI_Q": int(q),
                "AF": af,
                "DeltaCases": delta,
                "sum_mu_obs": sum_obs,
                "x_col": x_col,
                "x0": x0,
                "n_rows": n_rows,
            })
    return pd.DataFrame(rec)

def plot_af_lines(df_af, title, out_png, out_pdf):
    """Plot AF over time, one line per SDI group, and save PNG and PDF copies.

    A translucent band between ``ci_lo`` and ``ci_hi`` is added for a group
    when both columns exist and each holds at least one finite value for it.
    """
    plt.figure(figsize=(8.8, 5.0))

    has_ci_cols = ("ci_lo" in df_af.columns) and ("ci_hi" in df_af.columns)

    for q in range(1, SDI_Q + 1):
        series = df_af[df_af["SDI_Q"] == q].sort_values("year")
        if series.empty:
            continue

        colour = SDI_HEX[q - 1]
        xs = series["year"]
        plt.plot(xs, series["AF"], lw=2.6, color=colour, label=f"Q{q}")

        # CI shading, only when bounds are available for this group
        if not has_ci_cols:
            continue
        lo, hi = series["ci_lo"], series["ci_hi"]
        if not (np.isfinite(lo).any() and np.isfinite(hi).any()):
            continue
        plt.fill_between(xs, lo, hi, color=colour, alpha=0.12, linewidth=0)

    plt.axhline(0, lw=1, color="k", alpha=0.6)
    plt.xlabel("Year")
    plt.ylabel("Attributable fraction (AF)")
    plt.title(title)
    plt.legend(frameon=False, ncol=5)

    plt.tight_layout()
    plt.savefig(out_png, dpi=300, bbox_inches="tight")
    plt.savefig(out_pdf, bbox_inches="tight")
    plt.show()
    plt.close()
    for saved_fp in (out_png, out_pdf):
        print("Saved:", saved_fp)

# =========================
# MAIN
# =========================
def main():
    """Run the full pipeline: load, fit, map PAF for MAP_YEAR, plot AF(t) by SDI group.

    Outputs written to OUT_DIR: the model summary text, one PAF details CSV
    and PDF map per exposure, and one AF time-series CSV plus PNG/PDF figure
    per exposure (with bootstrap intervals).

    Raises
    ------
    ValueError
        If the prepared panel has no rows for MAP_YEAR.
    """
    print("Loading:", IN_FP)
    df = prepare_panel(pd.read_csv(IN_FP))

    print("Prepared:", df.shape, "| iso3:", df["iso3"].nunique(), "| years:", df["year"].min(), "-", df["year"].max())
    print("SDI_Q counts:", df["SDI_Q"].value_counts().sort_index().to_dict())

    # fit model
    res, formula = fit_country_fe_model(df)
    print("\n[MODEL]\n", formula)
    print("[AIC]", float(res.aic))

    # save summary
    summ_fp = os.path.join(OUT_DIR, "model_summary.txt")
    with open(summ_fp, "w", encoding="utf-8") as f:
        f.write("FORMULA:\n" + formula + "\n\n")
        f.write(res.summary().as_text())
    print("Saved:", summ_fp)

    # (file tag, title label, exposure column, counterfactual value)
    exposures = (
        ("HAP", "HAP", HAP_COL, TMREL_HAP),
        ("PM25", "PM2.5", PM25_COL, TMREL_PM25),
    )

    # ---------- (1) SPATIAL MAPS: PAF in MAP_YEAR ----------
    world = load_world_shp(SHP_FP)
    d_year = df[df["year"] == MAP_YEAR].copy()
    if d_year.empty:
        raise ValueError(f"No rows for MAP_YEAR={MAP_YEAR}")

    for tag, label, col, tmrel in exposures:
        paf = paf_country_year(res, d_year, x_col=col, x0=tmrel)
        details_fp = os.path.join(OUT_DIR, f"PAF_{tag}_{MAP_YEAR}_details.csv")
        paf.to_csv(details_fp, index=False, encoding="utf-8-sig")
        print("Saved:", details_fp)

        plot_paf_map(
            world, paf,
            title=f"PAF_{label} ({MAP_YEAR}) | CF: {col}={tmrel} | allow +/-",
            out_fp=os.path.join(OUT_DIR, f"PAF_{tag}_{MAP_YEAR}.pdf")
        )

    # ---------- (2) TEMPORAL: AF(t) by SDI quintiles ----------
    for tag, label, col, tmrel in exposures:
        af = af_timeseries_by_sdiq_ci(res, df, x_col=col, x0=tmrel, n_boot=300, seed=123)
        af_fp = os.path.join(OUT_DIR, f"AF_timeseries_bySDIQ_{tag}_withCI.csv")
        af.to_csv(af_fp, index=False, encoding="utf-8-sig")
        print("Saved:", af_fp)

        plot_af_lines(
            af,
            title=f"AF(t) by SDI quintiles | exposure={label} | CF {col}={tmrel}",
            out_png=os.path.join(OUT_DIR, f"AF_bySDIQ_{tag}_withCI.png"),
            out_pdf=os.path.join(OUT_DIR, f"AF_bySDIQ_{tag}_withCI.pdf"),
        )

    print("\nDONE. OUT_DIR =", OUT_DIR)

if __name__ == "__main__":
    main()
