In [8]:
# ============================================================
# ModuleB: Onset ERP (<1000ms) FULL PIPELINE (FIXED ch_names)
# - Keep alignment epochs <-> trial_feat by KEYS
# - Attach labels from MASTER (robust)
# - Normalize EEG channel names: "EEG Fp1-Ref" -> "Fp1"
# - Keep only the ROI EEG channels (EEG19 = 17 channels; T7/T8 are not in ROI_DEF) for ROI/topomap
# - ROI x window sign-flip permutation
# - ROI ERP plots + early inset + significant bands
# - Topomaps for early windows
# ============================================================

from pathlib import Path
import re
import numpy as np
import pandas as pd
import mne
import matplotlib.pyplot as plt

# =========================
# 0) CONFIG (your local paths)
# =========================
EPOCHS_PATH = Path("/Users/shunsuke/EEG_48sounds/derivatives/epochs_all/epochs_all-epo.fif")
MASTER_PATH = Path("/Users/shunsuke/EEG_48sounds/derivatives/master_tables/master_participant_sound_level_with_PC.csv")
TRIAL_FEAT_PATH = Path("/Users/shunsuke/EEG_48sounds/moduleB_outputs/tables/moduleB_trial_eeg_features.csv")

# Output root. Figures and tables go into two sub-directories that are
# created up front, so later savefig()/to_csv() calls never hit a missing folder.
OUT_DIR = Path("/Users/shunsuke/EEG_48sounds/moduleB_outputs/figs/ERP_onset_under1000ms_suite")
OUT_FIG, OUT_TAB = OUT_DIR / "figs", OUT_DIR / "tables"
for _out_subdir in (OUT_FIG, OUT_TAB):
    _out_subdir.mkdir(parents=True, exist_ok=True)
del _out_subdir

# Permutation / unit settings
N_PERM = 9999        # random sign-flip permutations per test
SEED = 0             # RNG seed handed to every permutation test
SCALE_TO_uV = True   # report amplitudes in µV (values are multiplied by 1e6)

# Analysis windows (ms) for the onset ERP: 0–1000 ms overall, with the
# 0–220 ms range emphasised (requested by the supervisor).
WINDOWS_MS = [
    (0, 80), (80, 140), (140, 220),
    (220, 350), (350, 500), (500, 800), (800, 1000),
]
EARLY_WINDOWS_MS = [(0, 80), (80, 140), (140, 220)]
TOPO_WINDOWS_MS = [(0, 80), (80, 140), (140, 220)]

# Regions of interest (10-20 channel labels).
ROI_DEF = {
    "frontal":   ["Fp1", "Fp2", "F7", "F3", "Fz", "F4", "F8"],
    "central":   ["C3", "Cz", "C4"],
    "parietal":  ["P7", "P3", "Pz", "P4", "P8"],
    "occipital": ["O1", "O2"],
}
# Sorted union of all ROI channels. NOTE(review): despite the name this is 17
# channels, because T7/T8 do not belong to any ROI.
EEG19 = sorted(set().union(*ROI_DEF.values()))

# Default label -> PC column assignment, used when MASTER only has PC columns
# (edit if your PCs are ordered differently).
PC_FALLBACK_MAP = {"emo_arousal": "PC1", "emo_approach": "PC2", "emo_valence": "PC3"}

# Composite key identifying a single trial in epochs.metadata and trial_feat.
KEYS = ["subject_id", "run_id", "trial_in_run", "number"]

# =========================
# 1) Utils
# =========================
def _standardize_columns(df: pd.DataFrame) -> pd.DataFrame:
    rename = {}
    cols = set(df.columns)

    if "subject_id" not in cols:
        for c in ["subject","participant","subj","subj_id"]:
            if c in cols:
                rename[c] = "subject_id"; break
    if "run_id" not in cols:
        for c in ["run","runid","run_idx"]:
            if c in cols:
                rename[c] = "run_id"; break
    if "trial_in_run" not in cols:
        for c in ["trial","trial_idx","trialIndex"]:
            if c in cols:
                rename[c] = "trial_in_run"; break
    if "number" not in cols:
        for c in ["sound_number","soundNo","sound_no","stim_number"]:
            if c in cols:
                rename[c] = "number"; break
    return df.rename(columns=rename) if rename else df

def _require_cols(df: pd.DataFrame, required: list[str], name: str):
    miss = [c for c in required if c not in df.columns]
    if miss:
        raise RuntimeError(f"[{name}] Missing columns: {miss}\ncols={list(df.columns)}")

def _to_int_from_any(s: pd.Series) -> pd.Series:
    def conv(x):
        if pd.isna(x): return np.nan
        if isinstance(x, (int, np.integer)): return int(x)
        if isinstance(x, (float, np.floating)) and np.isfinite(x): return int(x)
        st = str(x).strip()
        m = re.search(r"(\d+)", st)
        return int(m.group(1)) if m else np.nan
    return s.apply(conv).astype("Int64")

def benjamini_hochberg(pvals: np.ndarray) -> np.ndarray:
    """Return Benjamini–Hochberg FDR-adjusted p-values (q-values).

    Parameters
    ----------
    pvals : array-like of float
        Raw p-values, any shape. NaN entries are treated as "not tested":
        they do not count toward the number of tests ``m`` and stay NaN in
        the output. Previously a single NaN made every q-value NaN, because
        ``np.minimum.accumulate`` propagates NaN.

    Returns
    -------
    np.ndarray
        Adjusted p-values with the same shape as ``pvals``, clipped to [0, 1].
    """
    p = np.asarray(pvals, dtype=float)
    flat = p.ravel()
    q_flat = np.full(flat.shape, np.nan)

    valid = np.flatnonzero(~np.isnan(flat))
    m = valid.size
    if m:
        order = valid[np.argsort(flat[valid], kind="mergesort")]
        ranked = flat[order]
        q = ranked * m / np.arange(1, m + 1)
        # Step-up: each q is the min over itself and all larger-ranked q's.
        q = np.minimum.accumulate(q[::-1])[::-1]
        q_flat[order] = np.clip(q, 0.0, 1.0)
    return q_flat.reshape(p.shape)

def signflip_perm_t(diff_by_sub: np.ndarray, n_perm: int = 9999, seed: int = 0):
    """Two-sided one-sample sign-flip permutation test of the t statistic.

    Each of ``n_perm`` permutations randomly flips the sign of every
    subject's difference. The p-value is ``(#{|t_perm| >= |t_obs|} + 1) /
    (n_perm + 1)``.

    Parameters
    ----------
    diff_by_sub : array-like of float
        One difference value per subject; NaNs are ignored.
    n_perm : int
        Number of random sign-flip permutations.
    seed : int
        Seed for ``np.random.default_rng``.

    Returns
    -------
    (t_obs, p_perm) : tuple of float
        ``(nan, nan)`` if fewer than 3 non-NaN values.
        ``(0.0, 1.0)`` if all values are exactly zero. Previously this case
        returned ``t=inf`` with the minimum p-value.
        ``±inf`` for ``t_obs`` when all values are equal and non-zero; the
        sign now follows the mean.
    """
    x = np.asarray(diff_by_sub, float)
    x = x[~np.isnan(x)]
    n = x.size
    if n < 3:
        return np.nan, np.nan

    mu = x.mean()
    sd = x.std(ddof=1)
    if sd == 0 and mu == 0:
        # No effect at all, and every permutation is identical.
        return 0.0, 1.0

    rng = np.random.default_rng(seed)
    sqrt_n = np.sqrt(n)
    t_obs = mu / (sd / sqrt_n) if sd > 0 else np.copysign(np.inf, mu)

    flips = rng.choice([-1, 1], size=(n_perm, n))
    xp = flips * x[None, :]
    # A permutation with all-equal flipped values has sd 0 -> t = ±inf, which is
    # the correct extreme value; silence the resulting divide warnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        t_p = xp.mean(axis=1) / (xp.std(axis=1, ddof=1) / sqrt_n)
    p_perm = (np.sum(np.abs(t_p) >= np.abs(t_obs)) + 1) / (n_perm + 1)
    return t_obs, p_perm

def dz_effect(diff_by_sub: np.ndarray) -> float:
    """Cohen's d_z: mean / sample SD (ddof=1) of the per-subject differences.

    NaNs are ignored. Returns NaN when fewer than 2 values remain or the SD
    is not positive.
    """
    vals = np.asarray(diff_by_sub, float)
    kept = vals[~np.isnan(vals)]
    if kept.size >= 2:
        spread = kept.std(ddof=1)
        if spread > 0:
            return kept.mean() / spread
    return np.nan

def ensure_montage(epochs: mne.Epochs) -> mne.Epochs:
    """Return a copy of ``epochs`` with the ``standard_1020`` montage applied.

    Channels that are not in the template are left without positions
    (``on_missing="ignore"``) instead of raising. The input is not modified.
    """
    standard = mne.channels.make_standard_montage("standard_1020")
    out = epochs.copy()
    out.set_montage(standard, on_missing="ignore")
    return out

def _pick_col_by_keywords(df: pd.DataFrame, patterns: list[str], prefer_exact: list[str] | None = None):
    cols = list(df.columns)
    if prefer_exact:
        for c in prefer_exact:
            if c in df.columns:
                return c
    pat = re.compile("|".join(patterns), flags=re.IGNORECASE)
    hits = [c for c in cols if pat.search(str(c))]
    if not hits:
        return None
    hits = sorted(hits, key=lambda x: (len(x), x))
    return hits[0]

# =========================
# 2) FIX: normalize EEG channel names + keep only the EEG19 ROI channels (17 channels in practice)
# =========================
def normalize_and_pick_eeg19(epochs: mne.Epochs):
    """Rename EEG channels to bare 10-20 labels and keep only the EEG19 channels.

    Example: ``'EEG Fp1-Ref' -> 'Fp1'``. Non-EEG channels (e.g. ``POL ...``)
    are discarded first. Labels are mapped to the spelling used in ``EEG19``
    case-insensitively, with a case-insensitive ``-Ref`` suffix. So
    ``'EEG FP1-REF'`` also becomes ``'Fp1'``; previously it was dropped.

    Parameters
    ----------
    epochs : mne.Epochs
        Input epochs; not modified (a copy is made).

    Returns
    -------
    mne.Epochs
        Copy restricted to the channels of ``EEG19`` that are present.

    Raises
    ------
    RuntimeError
        If no EEG-like channel exists, or fewer than 10 EEG19 channels remain.
    """
    epochs = epochs.copy()

    # Candidate channels: names with the "EEG " prefix (this drops POL etc.) ...
    eeg_like = [ch for ch in epochs.ch_names if str(ch).startswith("EEG ")]
    if len(eeg_like) == 0:
        # ... or, if the names were already normalized, the bare EEG19 labels.
        eeg_like = [ch for ch in epochs.ch_names if ch in EEG19]

    if len(eeg_like) == 0:
        raise RuntimeError("EEGチャンネルが見つかりません（ch_namesに 'EEG ' も EEG19 も無い）。")

    epochs = epochs.pick_channels(eeg_like)

    # lower-case label -> canonical spelling used by ROI_DEF / EEG19
    canonical = {name.lower(): name for name in EEG19}
    rename = {}
    for ch in epochs.ch_names:
        s = str(ch)
        m = re.match(r"EEG\s+([A-Za-z0-9]+)(?:-Ref)?$", s, flags=re.IGNORECASE)
        if m:
            label = m.group(1)
        else:
            # Fallback for other spellings: strip the "EEG" token and a trailing "-Ref".
            label = re.sub(r"-Ref$", "", s.replace("EEG", "").strip(), flags=re.IGNORECASE)
        label = canonical.get(label.lower(), label)
        if label != ch:  # skip no-op renames
            rename[ch] = label
    if rename:
        epochs.rename_channels(rename)

    # Keep only the ROI channels. `keep` is a subset of EEG19, so nothing
    # outside EEG19 can survive this pick.
    keep = [ch for ch in EEG19 if ch in epochs.ch_names]
    if len(keep) < 10:
        raise RuntimeError(f"19chが十分に揃っていません。keep={keep}, all={epochs.ch_names}")

    return epochs.pick_channels(keep)

# =========================
# 3) Align epochs <-> trial_feat
# =========================
def align_epochs_to_trial_feat(epochs: mne.Epochs, trial_feat_csv: Path):
    """Keep only the epochs whose trial keys also appear in the trial-feature table.

    Trials are matched on ``KEYS`` (subject_id, run_id, trial_in_run, number)
    after coercing every key to int. Rows with an unparseable key are dropped.
    If the direct match is empty, integer offsets in [-2, 2] on run_id and
    trial_in_run are tried. The offset pair with the most matches is applied,
    and the shifted keys are written into the returned metadata.

    Returns
    -------
    (epochs, info) : tuple
        The selected epochs (original order preserved) and a dict describing
        the alignment. On any skip or failure, the input epochs are returned
        unchanged with ``info["aligned"] == False``.
    """
    if trial_feat_csv is None or (not trial_feat_csv.exists()):
        print("[WARN] TRIAL_FEAT not found. skip alignment.")
        return epochs, {"aligned": False, "reason": "trial_feat missing"}

    if epochs.metadata is None:
        print("[WARN] epochs.metadata is None. skip alignment.")
        return epochs, {"aligned": False, "reason": "epochs.metadata None"}

    tf = pd.read_csv(trial_feat_csv)
    tf = _standardize_columns(tf)
    _require_cols(tf, KEYS, "trial_feat")

    meta = _standardize_columns(epochs.metadata.copy()).reset_index(drop=True)
    _require_cols(meta, KEYS, "epochs.metadata")
    # Record each row's position in `epochs` BEFORE any row is dropped. The
    # selection below must index the original epochs; previously _eidx was
    # assigned after dropna and shifted whenever a key was NaN.
    meta["_eidx"] = np.arange(len(meta))

    for k in KEYS:
        meta[k] = _to_int_from_any(meta[k])
        tf[k]   = _to_int_from_any(tf[k])

    print("NaN rate meta keys:\n", meta[KEYS].isna().mean())
    print("NaN rate tf keys:\n", tf[KEYS].isna().mean())

    meta = meta.dropna(subset=KEYS).copy()
    tf   = tf.dropna(subset=KEYS).copy()

    for k in KEYS:
        meta[k] = meta[k].astype(int)
        tf[k]   = tf[k].astype(int)

    keep = tf[KEYS].drop_duplicates().copy()
    keep["_keep"] = True

    m = meta.merge(keep, on=KEYS, how="inner")
    roff_best, toff_best, best_n = 0, 0, len(m)

    if len(m) == 0:
        print("[WARN] direct merge=0. trying offset search ...")
        for roff in [-2,-1,0,1,2]:
            for toff in [-2,-1,0,1,2]:
                meta2 = meta.copy()
                meta2["run_id"] = meta2["run_id"] + roff
                meta2["trial_in_run"] = meta2["trial_in_run"] + toff
                mm = meta2.merge(keep, on=KEYS, how="inner")
                if len(mm) > best_n:
                    best_n = len(mm); roff_best, toff_best = roff, toff
        print(f"[BEST] nmatch={best_n}, run_off={roff_best}, trial_off={toff_best}")

        if best_n == 0:
            print("[WARN] alignment failed (0 match). Proceed WITHOUT alignment.")
            return epochs, {"aligned": False, "reason": "0 match even after offset", "run_off": roff_best, "trial_off": toff_best}

        meta["run_id"] = meta["run_id"] + roff_best
        meta["trial_in_run"] = meta["trial_in_run"] + toff_best
        m = meta.merge(keep, on=KEYS, how="inner")

    # Positions into the ORIGINAL epochs, in ascending (original) order.
    sel = np.sort(m["_eidx"].to_numpy())

    epochs2 = epochs[sel]
    # Look rows up by their original position (index = _eidx), then drop it.
    meta2 = meta.set_index("_eidx").loc[sel].reset_index(drop=True)
    epochs2.metadata = meta2

    print(f"Aligned epochs: {len(epochs2)} / {len(epochs)} (run_off={roff_best}, trial_off={toff_best})")
    return epochs2, {"aligned": True, "run_off": roff_best, "trial_off": toff_best, "n_aligned": len(epochs2)}

# =========================
# 4) Attach labels from MASTER (robust)
# =========================
def attach_labels_from_master_robust(epochs: mne.Epochs, master_csv: Path,
                                    pc_fallback_map: dict | None = None,
                                    make_high_by: str = "median"):
    """Merge per-(subject_id, number) labels from the MASTER table into epochs.metadata.

    Label columns produced (whichever can be derived from MASTER):

    - ``emo_approach``: numeric approach score, taken from a column matched by
      name. Otherwise it falls back to a PC column via ``pc_fallback_map``.
    - ``emo_approach_high``: 1/0 split of ``emo_approach`` at its median
      (``make_high_by="median"``) or mean (any other value). This is skipped
      if MASTER already has the column. Rows with a missing score stay NaN
      instead of being counted as "low".
    - ``is_ambiguous``: True/False parsed from a matched column.
      - Numbers: non-zero is True.
      - Text: true/false/yes/no/t/f/y/n.
      - Missing or unparseable entries stay NaN; previously ``astype(bool)``
        made them True.
    - ``category_3``: copied from a ``category``-like column if absent.

    Epochs whose subject_id/number cannot be parsed are dropped from both the
    data and the metadata, so their lengths stay equal. Label columns already
    present in epochs.metadata are replaced by the MASTER version; previously
    they produced ``*_x``/``*_y`` columns and a KeyError.

    Returns
    -------
    (epochs2, want_cols) : tuple
        A new Epochs object with merged metadata, and the attached label columns.

    Raises
    ------
    RuntimeError
        If epochs.metadata is None, key columns are missing, or no label
        column can be derived from MASTER.
    """
    def _parse_flag(s: pd.Series) -> pd.Series:
        """Parse a flag column to True/False, keeping missing/unrecognised entries as NaN."""
        num = pd.to_numeric(s, errors="coerce")
        txt = s.astype(str).str.strip().str.lower()
        flag = pd.Series(np.nan, index=s.index, dtype=object)
        has_num = num.notna()
        flag[has_num] = num[has_num] != 0
        flag[txt.isin(["true", "t", "yes", "y"])] = True
        flag[txt.isin(["false", "f", "no", "n"])] = False
        # Plain bool dtype when nothing is missing (matches the old output).
        return flag.astype(bool) if flag.notna().all() else flag

    if epochs.metadata is None:
        raise RuntimeError("epochs.metadata が None です。")

    master = pd.read_csv(master_csv)
    master = _standardize_columns(master)
    meta = _standardize_columns(epochs.metadata.copy()).reset_index(drop=True)

    _require_cols(master, ["subject_id","number"], "MASTER")
    _require_cols(meta,   ["subject_id","number"], "epochs.metadata")

    master["subject_id"] = _to_int_from_any(master["subject_id"])
    master["number"]     = _to_int_from_any(master["number"])
    meta["subject_id"]   = _to_int_from_any(meta["subject_id"])
    meta["number"]       = _to_int_from_any(meta["number"])

    master = master.dropna(subset=["subject_id","number"]).copy()
    # An epoch whose keys cannot be parsed cannot be labelled. Drop it from the
    # metadata here AND from the data at the end, so both keep the same length.
    valid = meta[["subject_id","number"]].notna().all(axis=1).to_numpy()
    if not valid.all():
        print(f"[WARN] {int((~valid).sum())} epochs without subject_id/number -> dropped")
    meta = meta.loc[valid].copy()
    master["subject_id"] = master["subject_id"].astype(int)
    master["number"]     = master["number"].astype(int)
    meta["subject_id"]   = meta["subject_id"].astype(int)
    meta["number"]       = meta["number"].astype(int)

    if pc_fallback_map is None:
        pc_fallback_map = PC_FALLBACK_MAP

    # Look for an existing approach-score column.
    # NOTE(review): the patterns also accept "avoid"/"回避" columns, which would be
    # used as the approach score as-is — confirm their sign convention.
    col_approach = _pick_col_by_keywords(master,
        patterns=[r"\bemo[_ ]?approach\b", r"\bapproach\b", r"接近", r"avoid", r"回避"],
        prefer_exact=["emo_approach","approach","Approach"]
    )

    # Candidate PC columns for the fallback.
    pc_cols = [c for c in master.columns if re.search(r"\bPC\s*\d+\b|\bPC\d+\b|\bpc\d+\b", str(c))]
    pc_cols_sorted = sorted(pc_cols, key=lambda x: (len(str(x)), str(x)))

    def resolve_pc(name):
        """Mapped PC column if present, else PC1/PC2/PC3 (any case), else the shortest PC-like column."""
        want = pc_fallback_map.get(name, None)
        if want and want in master.columns:
            return want
        for cand in ["PC1","PC2","PC3","pc1","pc2","pc3"]:
            if cand in master.columns:
                return cand
        return pc_cols_sorted[0] if pc_cols_sorted else None

    if col_approach is None:
        pc = resolve_pc("emo_approach")
        if pc:
            print(f"[INFO] emo_approach not found -> fallback to {pc}")
            col_approach = pc

    if col_approach and col_approach in master.columns:
        master["emo_approach"] = pd.to_numeric(master[col_approach], errors="coerce")

    # Binary split (only when MASTER does not already provide it).
    if "emo_approach_high" not in master.columns and "emo_approach" in master.columns:
        x = pd.to_numeric(master["emo_approach"], errors="coerce")
        thr = float(x.median()) if make_high_by == "median" else float(x.mean())
        high = (x >= thr).astype(float)
        # A missing score is an unknown group, not "low" (NaN >= thr is False).
        high[x.isna()] = np.nan
        master["emo_approach_high"] = high if high.isna().any() else high.astype(int)

    # is_ambiguous / category_3
    col_amb = _pick_col_by_keywords(master,
        patterns=[r"\bis[_ ]?ambiguous\b", r"ambig", r"曖昧", r"ambiguity"],
        prefer_exact=["is_ambiguous","ambiguous"]
    )
    if col_amb and col_amb in master.columns:
        master["is_ambiguous"] = _parse_flag(master[col_amb])

    if "category_3" not in master.columns:
        col_cat = _pick_col_by_keywords(master, patterns=[r"\bcategory\b", r"カテゴリ", r"\bcat\b"], prefer_exact=["category"])
        if col_cat and col_cat in master.columns:
            master["category_3"] = master[col_cat]

    want_cols = [c for c in ["emo_approach","emo_approach_high","is_ambiguous","category_3"] if c in master.columns]
    if len(want_cols) == 0:
        raise RuntimeError("MASTERからラベル列を検出できませんでした。列名を確認してください。")

    # MASTER is the source of truth for these labels: drop stale copies from the
    # metadata so the merge does not create *_x / *_y columns.
    stale = [c for c in want_cols if c in meta.columns]
    if stale:
        print(f"[INFO] replacing existing metadata columns with MASTER labels: {stale}")
        meta = meta.drop(columns=stale)

    # One label row per (subject, sound); a left merge keeps every epoch in order.
    lab = master[["subject_id","number"] + want_cols].drop_duplicates(["subject_id","number"])
    out = meta.merge(lab, on=["subject_id","number"], how="left")

    print("Attached label NaN rate:\n", out[want_cols].isna().mean().sort_values(ascending=False))
    print("Attached label value_counts (binary-ish):")
    for c in ["emo_approach_high","is_ambiguous"]:
        if c in out.columns:
            print(" ", c, dict(pd.Series(out[c]).value_counts(dropna=False)))

    epochs2 = epochs.copy() if valid.all() else epochs[np.flatnonzero(valid)]
    epochs2.metadata = out.reset_index(drop=True)
    return epochs2, want_cols

# =========================
# 5) Contrast builder
# =========================
def build_contrasts(epochs):
    """List the contrasts whose label column exists in ``epochs.metadata``.

    Each contrast is ``(column, display_name, high_value, low_value)``. The
    fixed candidate order (approach first, then ambiguity) is preserved.
    """
    available = set(epochs.metadata.columns)
    candidates = (
        ("emo_approach_high", "Approach High vs Low", 1, 0),
        ("is_ambiguous",      "Ambiguous vs Non", True, False),
    )
    return [spec for spec in candidates if spec[0] in available]

# =========================
# 6) ROI time series (subject-balanced)
# =========================
def compute_roi_timeseries_by_subject(epochs, cond_col, hi_val, lo_val, roi_name):
    """Per-subject ROI-averaged ERPs for the two levels of one condition.

    For each subject, epochs with ``cond_col == hi_val`` and epochs with
    ``cond_col == lo_val`` are averaged separately (``Epochs.average``). Each
    average is then averaged over the channels of ``ROI_DEF[roi_name]`` that
    exist in ``epochs``. Subjects lacking either level get an all-NaN row.

    Returns
    -------
    subs : np.ndarray of int
        Sorted subject IDs, one row per subject in the arrays below.
    times : np.ndarray
        ``epochs.times`` (seconds).
    hi_ts, lo_ts, diff_ts : np.ndarray, shape (n_subjects, n_times)
        High-level, low-level, and High−Low waveforms in data units.

    Raises
    ------
    RuntimeError
        If none of the ROI channels are present.
    """
    meta = epochs.metadata.copy()
    _require_cols(meta, ["subject_id", cond_col], "epochs.metadata")

    sub_num = pd.to_numeric(meta["subject_id"], errors="coerce")
    subs = np.sort(sub_num.dropna().astype(int).unique())
    times = epochs.times

    picks = [ch for ch in ROI_DEF[roi_name] if ch in epochs.ch_names]
    if len(picks) == 0:
        raise RuntimeError(f"ROI '{roi_name}' channels not found. ch_names={epochs.ch_names}")

    ep_roi = epochs.copy().pick_channels(picks)

    # Build the masks once, outside the loop, as plain numpy bool arrays that
    # never contain NA. With nullable label/ID dtypes the old Series masks could
    # hold <NA>, which .to_numpy() turns into an object array that Epochs
    # indexing rejects.
    sub_arr = sub_num.to_numpy(dtype=float, na_value=np.nan)
    is_hi = (meta[cond_col] == hi_val).fillna(False).to_numpy(dtype=bool)
    is_lo = (meta[cond_col] == lo_val).fillna(False).to_numpy(dtype=bool)

    hi_ts, lo_ts = [], []
    for sid in subs:
        in_sub = sub_arr == sid
        m_hi = in_sub & is_hi
        m_lo = in_sub & is_lo

        if not m_hi.any() or not m_lo.any():
            hi_ts.append(np.full(times.shape, np.nan))
            lo_ts.append(np.full(times.shape, np.nan))
            continue

        # Average the epochs first, then the ROI channels.
        hi_ts.append(ep_roi[m_hi].average().data.mean(axis=0))
        lo_ts.append(ep_roi[m_lo].average().data.mean(axis=0))

    hi_ts = np.vstack(hi_ts)
    lo_ts = np.vstack(lo_ts)
    diff_ts = hi_ts - lo_ts
    return subs, times, hi_ts, lo_ts, diff_ts

# =========================
# 7) Permutation table
# =========================
def make_perm_table(epochs, contrasts):
    """Build the ROI × window permutation-statistics table for every contrast.

    Contrasts are ``(cond_col, contrast_name, hi_val, lo_val)`` tuples; a
    contrast is skipped unless both levels occur in epochs.metadata. For every
    ROI in ROI_DEF and window [a, b) ms in WINDOWS_MS, the per-subject
    High−Low difference averaged over the window is tested with
    ``signflip_perm_t``. Windows with no samples or fewer than 3 subjects are
    skipped. Amplitude columns are multiplied by 1e6 when SCALE_TO_uV is set.

    Two BH-FDR columns are added: ``p_fdr_within_cond`` (per condition) and
    ``p_fdr_all`` (the whole table).

    Raises
    ------
    RuntimeError
        If no test could be run (the table would be empty).
    """
    factor = 1e6 if SCALE_TO_uV else 1.0
    records = []

    for cond_col, contrast_name, hi_val, lo_val in contrasts:
        labels = epochs.metadata[cond_col]
        print(f"{cond_col} value_counts:", dict(labels.value_counts(dropna=False)))
        both_present = (labels == hi_val).sum() != 0 and (labels == lo_val).sum() != 0
        if not both_present:
            print(f"[SKIP] {cond_col}: hi/lo not both present.")
            continue

        for roi in ROI_DEF:
            _subs, times, hi_ts, lo_ts, diff_ts = compute_roi_timeseries_by_subject(
                epochs, cond_col, hi_val, lo_val, roi
            )
            t_ms = times * 1000.0

            for a, b in WINDOWS_MS:
                win = (t_ms >= a) & (t_ms < b)
                if not win.any():
                    continue

                # One window-mean difference per subject.
                d = np.nanmean(diff_ts[:, win], axis=1)
                n_sub = int(np.sum(~np.isnan(d)))
                if n_sub < 3:
                    continue

                t_obs, p_perm = signflip_perm_t(d, n_perm=N_PERM, seed=SEED)
                records.append({
                    "cond": cond_col,
                    "contrast": contrast_name,
                    "roi": roi,
                    "window_ms": f"{a}-{b}",
                    "n_subjects": n_sub,
                    "hi_mean_uV": np.nanmean(np.nanmean(hi_ts[:, win], axis=1)) * factor,
                    "lo_mean_uV": np.nanmean(np.nanmean(lo_ts[:, win], axis=1)) * factor,
                    "mean_diff_uV": np.nanmean(d) * factor,
                    "sd_diff_uV": np.nanstd(d, ddof=1) * factor,
                    "T_obs": t_obs,
                    "p_perm": p_perm,
                    "dz": dz_effect(d),
                })

    table = pd.DataFrame(records)
    if table.empty:
        raise RuntimeError("Permutation table is empty. ラベル/条件/ROIを確認してください。")

    table["p_fdr_within_cond"] = np.nan
    for _cond, grp in table.groupby("cond", sort=False):
        table.loc[grp.index, "p_fdr_within_cond"] = benjamini_hochberg(grp["p_perm"].to_numpy())
    table["p_fdr_all"] = benjamini_hochberg(table["p_perm"].to_numpy())
    return table

# =========================
# 8) Plots
# =========================
def plot_roi_waveform_with_inset(times, hi_ts, lo_ts, title, out_png, sig_bands=None):
    """Plot grand-average High/Low ROI ERPs (±SEM) with an early 0–250 ms inset.

    Parameters
    ----------
    times : np.ndarray
        Time axis in seconds.
    hi_ts, lo_ts : np.ndarray, shape (n_subjects, n_times)
        Per-subject waveforms in data units; NaN rows are ignored.
    title : str
        Figure title.
    out_png : path-like
        Output image; the figure is closed after saving.
    sig_bands : list of (start_ms, end_ms, label) or None
        Windows to shade and label. Bands ending at or before 250 ms are also
        shaded in the inset.
    """
    t_ms = times * 1000.0
    hi_mean = np.nanmean(hi_ts, axis=0)
    lo_mean = np.nanmean(lo_ts, axis=0)

    def sem(x):
        """Standard error across subjects, ignoring NaN rows."""
        n = np.sum(~np.isnan(x), axis=0)
        sd = np.nanstd(x, axis=0, ddof=1)
        return sd / np.sqrt(np.maximum(n, 1))

    hi_sem = sem(hi_ts)
    lo_sem = sem(lo_ts)

    ylab = "Amplitude (µV)" if SCALE_TO_uV else "Amplitude (V)"
    if SCALE_TO_uV:
        hi_mean *= 1e6; lo_mean *= 1e6
        hi_sem  *= 1e6; lo_sem  *= 1e6

    fig = plt.figure(figsize=(14,5), dpi=180)
    ax = fig.add_subplot(1,1,1)

    ax.plot(t_ms, hi_mean, label="High", linewidth=2.0)
    ax.plot(t_ms, lo_mean, label="Low", linewidth=2.0)
    ax.fill_between(t_ms, hi_mean-hi_sem, hi_mean+hi_sem, alpha=0.18)
    ax.fill_between(t_ms, lo_mean-lo_sem, lo_mean+lo_sem, alpha=0.18)

    ax.axvline(0, linewidth=1.2)
    ax.axhline(0, linewidth=0.8)
    ax.set_xlim([-200, 1000])
    ax.set_xlabel("Time (ms)")
    ax.set_ylabel(ylab)
    ax.set_title(title)
    ax.grid(True, alpha=0.25)
    ax.legend(loc="center right")

    if sig_bands:
        # x in data units, y in axes fraction: the label sits just below the top
        # edge regardless of the y-limits. The old ylim[1]*0.92 landed mid-plot
        # whenever the upper limit was close to 0.
        label_tf = ax.get_xaxis_transform()
        for (a, b, label) in sig_bands:
            ax.axvspan(a, b, alpha=0.12)
            ax.text((a + b) / 2, 0.97, label, transform=label_tf,
                    ha="center", va="top", fontsize=9)

    inset_lo_ms, inset_hi_ms = 0, 250
    axins = fig.add_axes([0.62, 0.58, 0.33, 0.33])
    axins.plot(t_ms, hi_mean, linewidth=1.6)
    axins.plot(t_ms, lo_mean, linewidth=1.6)
    axins.set_xlim([inset_lo_ms, inset_hi_ms])
    axins.axvline(0, linewidth=1.0)
    axins.axhline(0, linewidth=0.6)
    axins.grid(True, alpha=0.25)

    # Zoom the inset's y-range onto the early window (including 0). Otherwise it
    # reuses the full -200..1000 ms range and the early deflections look flat.
    in_win = (t_ms >= inset_lo_ms) & (t_ms <= inset_hi_ms)
    early = np.concatenate([hi_mean[in_win], lo_mean[in_win]])
    early = early[np.isfinite(early)]
    if early.size:
        y0 = min(float(early.min()), 0.0)
        y1 = max(float(early.max()), 0.0)
        if y1 > y0:
            pad = 0.08 * (y1 - y0)
            axins.set_ylim(y0 - pad, y1 + pad)

    if sig_bands:
        for (a, b, _) in sig_bands:
            if b <= inset_hi_ms:
                axins.axvspan(a, b, alpha=0.12)

    fig.savefig(out_png, bbox_inches="tight")
    plt.close(fig)

def topomap_subject_balanced(epochs, cond_col, hi_val, lo_val, win_ms, out_png):
    """Save a topomap of the subject-balanced High−Low difference in one window.

    Each subject contributes one difference wave: the mean of their High
    epochs minus the mean of their Low epochs. This way subjects with more
    trials do not dominate. The grand mean over subjects is averaged over the
    half-open window [a, b) ms, using the same window definition as
    make_perm_table, and drawn with a colour bar.

    Parameters
    ----------
    epochs : mne.Epochs
        Epochs with ``subject_id`` and ``cond_col`` in their metadata.
    cond_col, hi_val, lo_val
        Label column and the values defining the High and Low levels.
    win_ms : (int, int)
        Window (a, b) in milliseconds.
    out_png : path-like
        Output image.

    Returns
    -------
    out_png or None
        The written path, or None if nothing was drawn: fewer than 3 subjects
        have both levels, or the window contains no samples.
    """
    meta = epochs.metadata.copy()
    _require_cols(meta, ["subject_id", cond_col], "epochs.metadata")

    epochs2 = ensure_montage(epochs)

    sub_num = pd.to_numeric(meta["subject_id"], errors="coerce")
    subs = np.sort(sub_num.dropna().astype(int).unique())
    a, b = win_ms

    # Plain numpy bool masks (never NA), built once outside the loop.
    sub_arr = sub_num.to_numpy(dtype=float, na_value=np.nan)
    is_hi = (meta[cond_col] == hi_val).fillna(False).to_numpy(dtype=bool)
    is_lo = (meta[cond_col] == lo_val).fillna(False).to_numpy(dtype=bool)

    diffs = []
    info_ref = None

    for sid in subs:
        in_sub = sub_arr == sid
        m_hi = in_sub & is_hi
        m_lo = in_sub & is_lo
        if not m_hi.any() or not m_lo.any():
            continue
        e_hi = epochs2[m_hi].average()
        e_lo = epochs2[m_lo].average()
        diffs.append(e_hi.data - e_lo.data)
        if info_ref is None:
            info_ref = e_hi.info

    if len(diffs) < 3:
        return None

    t_ms = epochs2.times * 1000.0
    in_win = (t_ms >= a) & (t_ms < b)
    if not in_win.any():
        return None

    diff_mean = np.mean(np.stack(diffs, axis=0), axis=0)
    topo = diff_mean[:, in_win].mean(axis=1)

    if SCALE_TO_uV:
        topo = topo * 1e6

    fig, ax = plt.subplots(figsize=(5,4), dpi=180)
    im, _ = mne.viz.plot_topomap(topo, info_ref, axes=ax, show=False, contours=0)
    # A difference map is not interpretable without its scale.
    fig.colorbar(im, ax=ax, shrink=0.8, label="µV" if SCALE_TO_uV else "V")
    ax.set_title(f"{cond_col}: High−Low ({a}-{b} ms)")
    fig.savefig(out_png, bbox_inches="tight")
    plt.close(fig)
    return out_png

# =========================
# 9) RUN
# =========================
print("EPOCHS:", EPOCHS_PATH)
print("MASTER:", MASTER_PATH)
print("TRIAL_FEAT:", TRIAL_FEAT_PATH)
print("OUT:", OUT_DIR)

# Load epochs (preload required)
epochs = mne.read_epochs(EPOCHS_PATH, preload=True, verbose="ERROR")
print("Loaded epochs:", len(epochs))

# Align to trial_feat (keep only trials that also appear in the trial-feature table)
epochs, align_info = align_epochs_to_trial_feat(epochs, TRIAL_FEAT_PATH)
print("align_info:", align_info)

# Attach labels from MASTER
epochs, attached_cols = attach_labels_from_master_robust(epochs, MASTER_PATH, pc_fallback_map=PC_FALLBACK_MAP)
print("Attached cols:", attached_cols)

# FIX: keep only the EEG19 ROI channels and normalize their names to "Fp1" etc.
epochs = normalize_and_pick_eeg19(epochs)
print("After normalize/pick EEG19 ch_names:", epochs.ch_names)

# Onset ERP preprocessing: standard montage, crop to -200..1000 ms, baseline -200..0 ms
epochs = ensure_montage(epochs)
epochs = epochs.copy().crop(-0.2, 1.0).apply_baseline((-0.2, 0.0))
print("After crop/baseline:", epochs.tmin, epochs.tmax, "n=", len(epochs))

# Contrasts
contrasts = build_contrasts(epochs)
print("Available CONTRASTS:", [c[0] for c in contrasts])
if len(contrasts) == 0:
    raise RuntimeError("使えるコントラストが0です。")

# Permutation table
df = make_perm_table(epochs, contrasts)
out_csv = OUT_TAB / "ERP_onset_under1000ms_perm_ALLCONDS.csv"
df.to_csv(out_csv, index=False)
print("[SAVED]", out_csv)

# Waveform plots. Shaded bands mark early windows with an UNCORRECTED
# p_perm < .05 (see the p_fdr_* columns of the table for corrected values).
df_early = df[df["window_ms"].isin([f"{a}-{b}" for a,b in EARLY_WINDOWS_MS])].copy()

for cond_col, contrast_name, hi_val, lo_val in contrasts:
    for roi in ROI_DEF.keys():
        subs, times, hi_ts, lo_ts, diff_ts = compute_roi_timeseries_by_subject(
            epochs, cond_col, hi_val, lo_val, roi
        )

        sig_bands = []
        for (a,b) in EARLY_WINDOWS_MS:
            w = f"{a}-{b}"
            hit = df_early[(df_early["cond"]==cond_col) & (df_early["roi"]==roi) & (df_early["window_ms"]==w)]
            if len(hit)==1 and float(hit["p_perm"].iloc[0]) < 0.05:
                sig_bands.append((a,b,f"{w} ms (p<.05)"))

        out_png = OUT_FIG / f"ERP_{cond_col}_{roi}_highlow_-200_1000ms.png"
        plot_roi_waveform_with_inset(
            times, hi_ts, lo_ts,
            title=f"{roi.upper()} ROI ERP: {contrast_name} (−200 to 1000 ms)",
            out_png=out_png,
            sig_bands=sig_bands
        )
        print("[SAVED]", out_png)

# Topomaps
for cond_col, contrast_name, hi_val, lo_val in contrasts:
    for (a,b) in TOPO_WINDOWS_MS:
        out_png = OUT_FIG / f"TopoDiff_{cond_col}_highminuslow_{a}-{b}ms.png"
        # topomap_subject_balanced writes nothing when it skips a window. Remove
        # any image left by a previous run first, so the exists() check below
        # cannot report (or leave behind) a stale figure as freshly saved.
        out_png.unlink(missing_ok=True)
        topomap_subject_balanced(epochs, cond_col, hi_val, lo_val, (a,b), out_png)
        if out_png.exists():
            print("[SAVED]", out_png)
        else:
            print("[SKIP]", out_png.name, "(not enough data for this window)")

print("\nDONE.")
print("Tables:", OUT_TAB)
print("Figs  :", OUT_FIG)


EPOCHS: /Users/shunsuke/EEG_48sounds/derivatives/epochs_all/epochs_all-epo.fif
MASTER: /Users/shunsuke/EEG_48sounds/derivatives/master_tables/master_participant_sound_level_with_PC.csv
TRIAL_FEAT: /Users/shunsuke/EEG_48sounds/moduleB_outputs/tables/moduleB_trial_eeg_features.csv
OUT: /Users/shunsuke/EEG_48sounds/moduleB_outputs/figs/ERP_onset_under1000ms_suite
Loaded epochs: 1728
NaN rate meta keys:
 subject_id      0.0
run_id          0.0
trial_in_run    0.0
number          0.0
dtype: float64
NaN rate tf keys:
 subject_id      0.0
run_id          0.0
trial_in_run    0.0
number          0.0
dtype: float64
Replacing existing metadata with 32 columns
Aligned epochs: 1728 / 1728 (run_off=0, trial_off=0)
align_info: {'aligned': True, 'run_off': 0, 'trial_off': 0, 'n_aligned': 1728}
Attached label NaN rate:
 emo_approach         0.0
emo_approach_high    0.0
is_ambiguous         0.0
category_3           0.0
dtype: float64
Attached label value_counts (binary-ish):
  emo_approach_high {1: np.i