In [2]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
ERP High vs Low figure generator (Arousal / Approach)
=====================================================
- epochs_all-epo.fif を読み込む
- epochs.metadata['number'] と master_sound_level_with_PC.csv['number'] を結合して
  emo_arousal_high / emo_approach_high / emo_valence_high を metadata に追加
  （無ければ emo_arousal / emo_approach から median split で作る）
- 重要特徴量CSV（moduleB_importance_*.csv）から ERP の「重要時間窓(ms)」を抽出し、
  High/Low ERP波形に重ねて可視化（0〜4000msのみ表示）
- 各チャンネル×各窓の平均振幅差（High - Low, µV）をCSVに保存

想定ファイル:
ROOT/
  derivatives/epochs_all/epochs_all-epo.fif
  derivatives/master_tables/master_sound_level_with_PC.csv
  moduleB_importance_emo_arousal_high_linear.csv
  moduleB_importance_emo_approach_high_linear.csv
"""

from __future__ import annotations
import re
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import mne
from matplotlib import font_manager

def set_matplotlib_style(font_scale: float = 2.0):
    """
    図5/6で「文字が小さくて読めない」を潰すための強制スタイル。
    Word貼り付けで縮小されても潰れにくい。
    """
    base = 10 * font_scale

    # 日本語が入る可能性がある図（タイトル等）対策：使える日本語フォントを自動選択
    jp_candidates = [
        "Hiragino Sans", "Yu Gothic", "YuGothic", "Meiryo",
        "Noto Sans CJK JP", "IPAexGothic"
    ]
    available = {f.name for f in font_manager.fontManager.ttflist}
    for f in jp_candidates:
        if f in available:
            plt.rcParams["font.family"] = f
            break
    plt.rcParams["axes.unicode_minus"] = False

    plt.rcParams.update({
        "font.size": base,
        "axes.titlesize": base * 1.15,
        "axes.labelsize": base * 1.10,
        "xtick.labelsize": base * 1.00,
        "ytick.labelsize": base * 1.00,
        "legend.fontsize": base * 1.00,
        "figure.titlesize": base * 1.35,
        "lines.linewidth": 2.6,
    })


# =========================
# ユーザー設定（ここだけ要確認）
# =========================
ROOT = Path("/Users/shunsuke/EEG_48sounds")

EPOCHS_FIF = ROOT / "derivatives/epochs_all/epochs_all-epo.fif"
MASTER_SOUND = ROOT / "derivatives/master_tables/master_sound_level_with_PC.csv"

IMP_AROUSAL = ROOT / "moduleB_importance_emo_arousal_high_linear.csv"
IMP_APPROACH = ROOT / "moduleB_importance_emo_approach_high_linear.csv"

OUT_DIR = ROOT / "paper_figs"
OUT_DIR.mkdir(parents=True, exist_ok=True)

# 波形表示範囲（ms）
PLOT_TMIN_MS = -200
PLOT_TMAX_MS = 4000

# ベースライン（秒）※epochsが -2s〜 ある前提
BASELINE_SEC = (-0.2, 0.0)

# 代表チャンネル（論文に出す用）
AROUSAL_CHS = ["P4", "Pz", "F3", "F8"]
APPROACH_CHS = ["T7", "Fp2", "F8", "P4"]  # T7はデータ上T3になってることが多いので自動で吸収


# =========================
# ユーティリティ
# =========================
def _assert_exists(path: Path) -> None:
    if not path.exists():
        raise FileNotFoundError(f"File not found: {path}")


def add_sound_labels_by_number(
    epochs: mne.Epochs,
    master_sound_csv: Path,
    key_epochs: str = "number",
    key_master: str = "number",
) -> mne.Epochs:
    """
    epochs.metadata[key_epochs] と master_sound[key_master] を結合して
    emo_* を epochs.metadata に追加する。
    """
    _assert_exists(master_sound_csv)

    if epochs.metadata is None:
        raise RuntimeError("epochs.metadata が None です。Epochs作成時に metadata を付与してください。")

    md = epochs.metadata.copy()
    if key_epochs not in md.columns:
        raise KeyError(f"epochs.metadata に '{key_epochs}' がありません。metadata columns={list(md.columns)}")

    master = pd.read_csv(master_sound_csv)

    if key_master not in master.columns:
        raise KeyError(f"master_sound に '{key_master}' がありません。master columns={list(master.columns)}")

    # 型を揃える（numberは数値のことが多い）
    md[key_epochs] = pd.to_numeric(md[key_epochs], errors="coerce")
    master[key_master] = pd.to_numeric(master[key_master], errors="coerce")

    # 取り込みたい列（存在するものだけ）
    want_cols = [
        key_master,
        "emo_arousal_high",
        "emo_approach_high",
        "emo_valence_high",
        "emo_arousal",
        "emo_approach",
        "emo_valence",
    ]
    take = [c for c in want_cols if c in master.columns]

    merged = md.merge(
        master[take],
        left_on=key_epochs,
        right_on=key_master,
        how="left",
        validate="many_to_one",
    )

    # match率チェック
    match_rate = merged[key_master].notna().mean()
    print(f"[OK] merge key: metadata.{key_epochs} <-> master.{key_master} | match_rate={match_rate:.3f}")

    # high列が無い場合は median split で作る（48音のsound-level想定）
    def ensure_high(col_raw: str, col_high: str) -> None:
        if col_high in merged.columns and merged[col_high].notna().any():
            return
        if col_raw in merged.columns and merged[col_raw].notna().any():
            med = merged[col_raw].median()
            merged[col_high] = (merged[col_raw] >= med).astype(int)
            print(f"[make] {col_high} from {col_raw} (median split, median={med:.3f})")
        else:
            print(f"[warn] {col_high} を作れません（{col_raw} も {col_high} も見つからない）")

    ensure_high("emo_arousal", "emo_arousal_high")
    ensure_high("emo_approach", "emo_approach_high")
    ensure_high("emo_valence", "emo_valence_high")

    # epochsへ反映
    merged = merged.drop(columns=[key_master], errors="ignore")
    epochs = epochs.copy()
    epochs.metadata = merged
    print(f"[metadata] columns={len(epochs.metadata.columns)} added/kept.")
    return epochs


def load_importance_windows_ms(importance_csv: Path, topk: int = 10) -> list[tuple[int, int]]:
    """
    moduleB_importance_*.csv から、ERPの時間窓(ms)のみ抽出。
    例: ERP_P4_2000_3000ms -> (2000,3000)
    表示の都合で 0〜4000ms に重なる窓のみ返す。
    """
    if not importance_csv.exists():
        print(f"[warn] importance csv not found: {importance_csv} -> fallback manual windows")
        return []

    df = pd.read_csv(importance_csv)
    if "feature" not in df.columns:
        print(f"[warn] 'feature' col not found in {importance_csv.name}")
        return []

    df = df.head(topk).copy()

    wins = []
    for feat in df["feature"].astype(str).tolist():
        if not feat.startswith("ERP_"):
            continue
        m = re.match(r"^ERP_[^_]+_(?P<t1>-?\d+)_(?P<t2>-?\d+)ms$", feat)
        if not m:
            continue
        t1 = int(m.group("t1"))
        t2 = int(m.group("t2"))
        # 0〜4000msに重なる窓だけ
        if t2 < 0 or t1 > 4000:
            continue
        wins.append((t1, t2))

    # 順序を保ってユニーク化
    uniq = []
    seen = set()
    for w in wins:
        if w not in seen:
            uniq.append(w)
            seen.add(w)

    return uniq


def resolve_channel_name(ch_names: list[str], short: str) -> str:
    """
    例:
      short='P4' -> 'EEG P4-Ref' 等に解決
      short='T7' -> 多くのEEGでは 'T3' 表記なので自動で寄せる
    """
    alias = {
        "T7": "T3",
        "T8": "T4",
    }
    targets = [short]
    if short in alias:
        targets.append(alias[short])

    # 候補パターン（強い順）
    patterns = []
    for t in targets:
        patterns += [
            rf"^{re.escape(t)}$",
            rf"^EEG\s+{re.escape(t)}$",
            rf"^EEG\s+{re.escape(t)}-",
            rf".*\b{re.escape(t)}-Ref\b.*",
            rf".*\b{re.escape(t)}\b.*",
        ]

    for pat in patterns:
        rx = re.compile(pat)
        hits = [c for c in ch_names if rx.match(c) or rx.search(c)]
        if len(hits) >= 1:
            return hits[0]

    raise ValueError(
        f"チャンネル '{short}' を解決できません。"
        f"利用可能ch例: {ch_names[:10]} ... total={len(ch_names)}"
    )


def compute_window_diff_uv(
    ev_hi: mne.Evoked,
    ev_lo: mne.Evoked,
    ch_short: str,
    windows_ms: list[tuple[int, int]],
) -> pd.DataFrame:
    """
    各窓の平均振幅（µV）の High-Low 差分を返す
    """
    ch_name = resolve_channel_name(ev_hi.ch_names, ch_short)
    times_ms = ev_hi.times * 1000.0

    y_hi = ev_hi.copy().pick(ch_name).data[0] * 1e6
    y_lo = ev_lo.copy().pick(ch_name).data[0] * 1e6

    rows = []
    for (a, b) in windows_ms:
        m = (times_ms >= a) & (times_ms <= b)
        if m.sum() == 0:
            continue
        rows.append(
            {
                "channel_short": ch_short,
                "channel_in_data": ch_name,
                "window_ms": f"{a}-{b}",
                "mean_high_uV": float(np.mean(y_hi[m])),
                "mean_low_uV": float(np.mean(y_lo[m])),
                "diff_high_minus_low_uV": float(np.mean(y_hi[m]) - np.mean(y_lo[m])),
            }
        )
    return pd.DataFrame(rows)


def plot_erp_hilo(
    epochs: mne.Epochs,
    label_col: str,
    chs_short: list[str],
    windows_ms: list[tuple[int, int]],
    out_png: Path,
    title: str,
    *,
    font_scale: float = 2.0,
    figsize: tuple[float, float] = (16, 10),
    dpi: int = 450,
) -> None:
    """
    - 2×2固定（chs_shortは4ch前提）
    - 全サブプロットで軸の数値（tick）と凡例を表示
    - フォント・dpiを上げて「Word貼り付けで潰れる」を回避
    - PNG + PDF を保存（PDFは拡大しても文字が劣化しにくい）
    """
    if epochs.metadata is None or label_col not in epochs.metadata.columns:
        raise KeyError(
            f"metadataに '{label_col}' が必要です。metadata cols={list(epochs.metadata.columns) if epochs.metadata is not None else None}"
        )
    if len(chs_short) != 4:
        raise ValueError("2×2表示のため chs_short は4チャンネルにしてください。")

    set_matplotlib_style(font_scale=font_scale)

    # EEGのみ + baseline
    ep = epochs.copy().pick_types(eeg=True)
    if BASELINE_SEC is not None:
        ep = ep.apply_baseline(BASELINE_SEC)

    md = ep.metadata
    ep_hi = ep[md[label_col] == 1]
    ep_lo = ep[md[label_col] == 0]
    ev_hi = ep_hi.average()
    ev_lo = ep_lo.average()

    times_ms = ev_hi.times * 1000.0
    m_plot = (times_ms >= PLOT_TMIN_MS) & (times_ms <= PLOT_TMAX_MS)
    x = times_ms[m_plot]

    # ---- 先に全chのy範囲を計算して、全パネルで同じylimにする（比較しやすい＋見やすい） ----
    y_stack = []
    for ch_short in chs_short:
        ch_name = resolve_channel_name(ev_hi.ch_names, ch_short)
        y_hi = ev_hi.copy().pick(ch_name).data[0, m_plot] * 1e6
        y_lo = ev_lo.copy().pick(ch_name).data[0, m_plot] * 1e6
        y_stack.append(y_hi)
        y_stack.append(y_lo)
    y_all = np.concatenate(y_stack)
    lo, hi = np.percentile(y_all, [1, 99])  # 外れ値に引っ張られない
    pad = (hi - lo) * 0.15 if hi > lo else 1.0
    y_min, y_max = lo - pad, hi + pad

    fig, axes = plt.subplots(2, 2, figsize=figsize, constrained_layout=True)
    axes = axes.ravel()

    for ax, ch_short in zip(axes, chs_short):
        ch_name = resolve_channel_name(ev_hi.ch_names, ch_short)

        y_hi = ev_hi.copy().pick(ch_name).data[0, m_plot] * 1e6
        y_lo = ev_lo.copy().pick(ch_name).data[0, m_plot] * 1e6

        ax.plot(x, y_hi, label="High")
        ax.plot(x, y_lo, label="Low")
        ax.axvline(0, linewidth=2.0)

        for (a, b) in windows_ms:
            aa = max(a, PLOT_TMIN_MS)
            bb = min(b, PLOT_TMAX_MS)
            if aa < bb:
                ax.axvspan(aa, bb, alpha=0.12)

        ax.set_title(f"{ch_short} ({ch_name})", pad=8)
        ax.set_xlabel("Time (ms)")
        ax.set_ylabel("Amplitude (µV)")
        ax.set_xlim(PLOT_TMIN_MS, PLOT_TMAX_MS)
        ax.set_ylim(y_min, y_max)

        # ★要件：全サブプロットで数値（tick）を“必ず”表示＆太くする
        ax.tick_params(axis="both", which="major", length=7, width=1.6,
                       labelbottom=True, labelleft=True)

        # ★要件：全サブプロットで凡例
        ax.legend(loc="upper right", frameon=True)

        ax.grid(True, alpha=0.25)

    fig.suptitle(title)

    out_png.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_png, dpi=dpi, bbox_inches="tight")
    fig.savefig(out_png.with_suffix(".pdf"), dpi=dpi, bbox_inches="tight")  # ←論文/要旨はPDFが強い
    plt.close(fig)

    # 窓平均との差分CSVも保存
    if windows_ms:
        dfs = [compute_window_diff_uv(ev_hi, ev_lo, ch, windows_ms) for ch in chs_short]
        out_csv = out_png.with_suffix(".csv")
        pd.concat(dfs, axis=0, ignore_index=True).to_csv(out_csv, index=False, encoding="utf-8-sig")
        print(f"[SAVE] {out_png}")
        print(f"[SAVE] {out_png.with_suffix('.pdf')}")
        print(f"[SAVE] {out_csv}")
    else:
        print(f"[SAVE] {out_png}")
        print(f"[SAVE] {out_png.with_suffix('.pdf')}")




# =========================
# main
# =========================
def main():
    _assert_exists(EPOCHS_FIF)
    _assert_exists(MASTER_SOUND)

    print(f"[LOAD] {EPOCHS_FIF}")
    epochs = mne.read_epochs(EPOCHS_FIF, preload=True)

    # まずラベルを metadata に足す（sound_id不要。number結合で完結）
    epochs = add_sound_labels_by_number(epochs, MASTER_SOUND, key_epochs="number", key_master="number")

    # 重要窓（top10、ERPのみ、0-4000msのみ）
    arousal_windows = load_importance_windows_ms(IMP_AROUSAL, topk=10)
    approach_windows = load_importance_windows_ms(IMP_APPROACH, topk=10)

    # 念のため：窓が空なら手動fallback（あなたの上位特徴に合わせた代表窓）
    if not arousal_windows:
        arousal_windows = [(400, 800), (800, 1200), (2000, 3000)]
    if not approach_windows:
        approach_windows = [(0, 200), (400, 800), (1200, 2000), (2000, 3000), (3000, 4000)]

    plot_erp_hilo(
        epochs,
        label_col="emo_arousal_high",
        chs_short=AROUSAL_CHS,
        windows_ms=arousal_windows,
        out_png=OUT_DIR / "fig5_erp_hilo_arousal.png",
        title="ERP (High vs Low): Arousal",
        font_scale=2.1,
        figsize=(16, 10),
        dpi=500,
    )

    plot_erp_hilo(
        epochs,
        label_col="emo_approach_high",
        chs_short=APPROACH_CHS,
        windows_ms=approach_windows,
        out_png=OUT_DIR / "fig6_erp_hilo_approach.png",
        title="ERP (High vs Low): Approach",
        font_scale=2.1,
        figsize=(16, 10),
        dpi=500,
    )


    print("[DONE] figures ->", OUT_DIR)


if __name__ == "__main__":
    main()


[LOAD] /Users/shunsuke/EEG_48sounds/derivatives/epochs_all/epochs_all-epo.fif
Reading /Users/shunsuke/EEG_48sounds/derivatives/epochs_all/epochs_all-epo.fif ...
Isotrak not found
    Found the data of interest:
        t =   -2000.00 ...   12000.00 ms
        0 CTF compensation matrices available
Reading /Users/shunsuke/EEG_48sounds/derivatives/epochs_all/epochs_all-epo.fif ...
Isotrak not found
    Found the data of interest:
        t =   -2000.00 ...   12000.00 ms
        0 CTF compensation matrices available
Adding metadata with 32 columns
1728 matching events found
No baseline correction applied
0 projection items activated
[OK] merge key: metadata.number <-> master.number | match_rate=1.000
[make] emo_arousal_high from emo_arousal (median split, median=0.343)
[make] emo_approach_high from emo_approach (median split, median=-0.390)
[make] emo_valence_high from emo_valence (median split, median=0.064)
Replacing existing metadata with 37 columns
[metadata] columns=37 added/kept.
[wa