In [8]:
"""
48音EEG用「統合前・最強特徴量」抽出スクリプト（trial → sound → subject）
★統合解析(⑥)でそのままJOINできるように、trial出力に number/FileName を必ず載せる版

入力:
- derivatives/master_epoch_index.csv
- derivatives/epochs_trial/<subject>/fif/<subject>_<run>_trialXX-epo.fif

出力:
- derivatives/eeg_features_trial.csv
    -> 1行 = 1 trial（統合解析の基準テーブルにJOIN可能）
- derivatives/eeg_features_sound.csv
    -> 1行 = 1 sound(number×FileName)（※number/FileNameが揃っている時のみ）
- derivatives/eeg_features_subject.csv
    -> 1行 = 1 participant（subject→Pxx変換後、参加者平均）

注意:
- master_epoch_index に number / FileName が無い場合、sound集約は「trial位置」になってしまうので
  この版では “sound集約を原則スキップ” し、監査ログを出します。
"""

from __future__ import annotations

from pathlib import Path
from dataclasses import dataclass
from typing import Dict, List, Any, Optional

import numpy as np
import pandas as pd
import mne
from scipy.signal import welch, hilbert
from scipy.stats import pearsonr, spearmanr
from scipy.cluster.vq import kmeans2


# ============================================================
# 0. PATH & GLOBAL SETTINGS
# ============================================================

# ★ここだけ自分の環境に合わせて変更
ROOT_DIR = Path("/Users/shunsuke/EEG_48sounds")

DERIV_DIR = ROOT_DIR / "derivatives"
DERIV_DIR.mkdir(parents=True, exist_ok=True)

# 入力
MASTER_INDEX_CSV = DERIV_DIR / "master_epoch_index.csv"
TRIAL_BASE_DIR   = DERIV_DIR / "epochs_trial"

# 出力
OUT_TRIAL_CSV = DERIV_DIR / "eeg_features_trial.csv"
OUT_SOUND_CSV = DERIV_DIR / "eeg_features_sound.csv"
OUT_SUBJ_CSV  = DERIV_DIR / "eeg_features_subject.csv"

# 監査ログ
AUDIT_DIR = ROOT_DIR / "output" / "integration_audit"
AUDIT_DIR.mkdir(parents=True, exist_ok=True)

# 処理方針
PROCESS_ONLY_QC_PASS = True   # Trueなら qc_pass==True のtrialだけ特徴抽出（推奨）
KEEP_ERROR_ROWS      = True   # Trueならエラーtrialも feature_ok=False で行を残す（統合が安定）

# フェーズ定義（音開始=0, 音長=5s を想定）
PHASES = {
    "A": (-2.0, -0.5),   # 広いベースライン
    "B": (-0.5, 0.0),    # 直前期（SPN/CNV）
    "C": (0.0, 0.5),     # オンセット
    "D": (0.5, 3.0),     # 持続
    "E": (3.0, 5.0),     # 末期
    "F": (5.0, 12.0),    # オフ後（余韻）
}

# 周波数帯
FREQ_BANDS = {
    "delta": (1.0, 4.0),
    "theta": (4.0, 7.0),
    "alpha": (8.0, 13.0),
    "beta":  (13.0, 30.0),
    "gamma": (30.0, 80.0),
}

# ROI（チャネル名の部分一致）
ROI_PATTERNS = {
    "FC": ["Fz", "Cz"],
    "P":  ["Pz", "P3", "P4"],
    "Pz": ["Pz"],
    "Cz": ["Cz"],
    "Fz": ["Fz"],
    "OCC": ["O1", "O2"],
    "F_left":  ["Fp1", "F3", "F7"],
    "F_right": ["Fp2", "F4", "F8"],
}

FAA_LEFT_PATTERNS  = ["Fp1", "F3", "F7"]
FAA_RIGHT_PATTERNS = ["Fp2", "F4", "F8"]

# 結合性のペア（わかりやすい命名）
CONNECTIVITY_PAIRS = {
    "FzCz": (["Fz"], ["Cz"]),
    "FzPz": (["Fz"], ["Pz"]),
    "F3P3": (["F3"], ["P3"]),
    "F4P4": (["F4"], ["P4"]),
    "T3T4": (["T3"], ["T4"]),
}


# ============================================================
# 1. UTILS
# ============================================================

def subject_to_participant(subject: str) -> Optional[str]:
    """
    subject文字列の先頭数字を拾って Pxx に変換。
    例: "1_高見" -> "P01"
    """
    if subject is None:
        return None
    s = str(subject)
    import re
    m = re.match(r"(\d+)", s)
    if not m:
        return None
    return f"P{int(m.group(1)):02d}"


def normalize_run(run: Any) -> str:
    """
    run表記ゆれ（1,2,3 / run1 / Run1 / etc）を run1 形式へ寄せる。
    """
    r = str(run).strip()
    # すでに run1 などならそのまま
    if r.lower().startswith("run"):
        # run01 のような場合も run1 に寄せる
        import re
        m = re.match(r"run\s*0*(\d+)", r, flags=re.IGNORECASE)
        if m:
            return f"run{int(m.group(1))}"
        return r.lower()
    # 数字だけなら runN
    try:
        n = int(float(r))
        return f"run{n}"
    except Exception:
        return r.lower()


def pick_by_patterns(info: mne.Info, patterns: List[str]) -> np.ndarray:
    """チャネル名に指定パターンが含まれるインデックスを返す。"""
    ch_names = np.array(info["ch_names"])
    idx = [i for i, name in enumerate(ch_names) if any(pat in name for pat in patterns)]
    return np.asarray(idx, dtype=int)


def get_time_mask(times: np.ndarray, tmin: float, tmax: float) -> np.ndarray:
    """timesベクトルから [tmin, tmax) のブールマスクを作る。"""
    return (times >= tmin) & (times < tmax)


def baseline_correct(data: np.ndarray, times: np.ndarray, tmin: float = -0.2, tmax: float = 0.0) -> np.ndarray:
    """単一trial(n_channels, n_times)に対するベースライン補正。"""
    mask = get_time_mask(times, tmin, tmax)
    if not np.any(mask):
        return data.copy()
    baseline = data[:, mask].mean(axis=1, keepdims=True)
    return data - baseline


def safe_mean(x: np.ndarray) -> float:
    """空ならNaN、それ以外は平均。"""
    if x.size == 0:
        return np.nan
    return float(np.mean(x))


def spd_logm(C: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """
    SPD行列用の logm（固有値分解ベース）。
    """
    C_sym = 0.5 * (C + C.T)
    w, V = np.linalg.eigh(C_sym)
    w = np.clip(w, eps, None)
    logw = np.log(w)
    return (V * logw) @ V.T
    
def canonicalize_filename(name: Any) -> Any:
    """
    FileName の表記ゆれを統合解析JOIN向けに正規化。
    - 余計なメモ（例: '.wav 30秒後'）を切り落とす
    - パスが混じっても basename にする
    - NAはNAのまま維持
    """
    if pd.isna(name):
        return pd.NA

    s = str(name).strip()
    # パスが入っていたらファイル名だけに
    s = Path(s).name

    # 連続空白を1つに
    import re
    s = re.sub(r"\s+", " ", s).strip()

    # ".wav" までで切る（大文字小文字無視）
    m = re.search(r"\.wav", s, flags=re.IGNORECASE)
    if m:
        s = s[: m.end()]

    return s


# ============================================================
# 2. ERP / SLOW POTENTIALS
# ============================================================

def erp_peak(data_bc: np.ndarray, times: np.ndarray, info: mne.Info, roi_key: str,
             tmin: float, tmax: float, mode: str = "min") -> float:
    """
    ROI・時間窓でのERPピーク値。
    mode: "min" / "max" / "mean"
    """
    picks = pick_by_patterns(info, ROI_PATTERNS[roi_key])
    if picks.size == 0:
        return np.nan
    mask = get_time_mask(times, tmin, tmax)
    if not np.any(mask):
        return np.nan
    roi_wave = data_bc[picks][:, mask].mean(axis=0)
    if mode == "min":
        return float(np.min(roi_wave))
    if mode == "max":
        return float(np.max(roi_wave))
    if mode == "mean":
        return float(np.mean(roi_wave))
    raise ValueError(f"Unknown mode: {mode}")


def sustained_potential(data_bc: np.ndarray, times: np.ndarray, info: mne.Info,
                        roi_key: str, tmin: float, tmax: float) -> float:
    """ROI・時間窓での平均電位（持続ポテンシャル）。"""
    picks = pick_by_patterns(info, ROI_PATTERNS[roi_key])
    if picks.size == 0:
        return np.nan
    mask = get_time_mask(times, tmin, tmax)
    if not np.any(mask):
        return np.nan
    roi_data = data_bc[picks][:, mask]
    return safe_mean(roi_data)


# ============================================================
# 3. BAND POWER / ERD-ERS / FAA
# ============================================================

def bandpower_roi(data: np.ndarray, times: np.ndarray, sfreq: float, info: mne.Info,
                  roi_key: str, band_name: str, tmin: float, tmax: float, nperseg: int = 256) -> float:
    """Welch法でROIのバンドパワーを算出。"""
    picks = pick_by_patterns(info, ROI_PATTERNS[roi_key])
    if picks.size == 0:
        return np.nan
    mask = get_time_mask(times, tmin, tmax)
    if not np.any(mask):
        return np.nan

    seg = data[picks][:, mask]
    fmin, fmax = FREQ_BANDS[band_name]

    bp_list = []
    for ch in range(seg.shape[0]):
        f, Pxx = welch(seg[ch], fs=sfreq, nperseg=min(nperseg, seg.shape[1]))
        band_mask = (f >= fmin) & (f <= fmax)
        if np.any(band_mask):
            bp_list.append(Pxx[band_mask].mean())

    if len(bp_list) == 0:
        return np.nan
    return float(np.mean(bp_list))


def erd_ers_db(bp_phase: float, bp_baseline: float, eps: float = 1e-12) -> float:
    """ERD/ERSをdB表現で返す。"""
    if np.isnan(bp_phase) or np.isnan(bp_baseline):
        return np.nan
    return float(10.0 * np.log10((bp_phase + eps) / (bp_baseline + eps)))


def faa_alpha(data: np.ndarray, times: np.ndarray, sfreq: float, info: mne.Info,
              tmin: float, tmax: float, nperseg: int = 256) -> float:
    """前頭αバンド左右差（FAA = logα右 − logα左）。"""
    fmin, fmax = FREQ_BANDS["alpha"]

    def bp_side(patterns: List[str]) -> float:
        picks = pick_by_patterns(info, patterns)
        if picks.size == 0:
            return np.nan
        mask = get_time_mask(times, tmin, tmax)
        if not np.any(mask):
            return np.nan
        seg = data[picks][:, mask]
        vals = []
        for ch in range(seg.shape[0]):
            f, Pxx = welch(seg[ch], fs=sfreq, nperseg=min(nperseg, seg.shape[1]))
            band_mask = (f >= fmin) & (f <= fmax)
            if np.any(band_mask):
                vals.append(Pxx[band_mask].mean())
        if len(vals) == 0:
            return np.nan
        return float(np.mean(vals))

    bp_left = bp_side(FAA_LEFT_PATTERNS)
    bp_right = bp_side(FAA_RIGHT_PATTERNS)
    if np.isnan(bp_left) or np.isnan(bp_right):
        return np.nan
    return float(np.log(bp_right) - np.log(bp_left))


# ============================================================
# 4. RIEMANN / CONNECTIVITY / COMPLEXITY
# ============================================================

def compute_riemann_features(data: np.ndarray, times: np.ndarray, sfreq: float, info: mne.Info) -> Dict[str, float]:
    """フェーズごとのlog共分散（上三角）を特徴量として出す。"""
    feats: Dict[str, float] = {}
    phases_for_cov = ["B", "C", "D", "E", "F"]
    n_ch = data.shape[0]
    eps = 1e-6

    for phase_key in phases_for_cov:
        tmin, tmax = PHASES[phase_key]
        mask = get_time_mask(times, tmin, tmax)
        if not np.any(mask):
            continue

        seg = data[:, mask]
        if seg.shape[1] < n_ch + 1:
            continue

        C = np.cov(seg) + eps * np.eye(n_ch)
        try:
            C_log = spd_logm(C, eps=1e-12)
        except Exception:
            C_log = np.full_like(C, np.nan)

        for i in range(n_ch):
            for j in range(i, n_ch):
                feats[f"RG_phase{phase_key}_{i:02d}_{j:02d}"] = float(np.real(C_log[i, j]))

    return feats


def compute_connectivity_features(data: np.ndarray, times: np.ndarray, sfreq: float, info: mne.Info) -> Dict[str, float]:
    """PLV/Coherence を band×phase×pair で算出。"""
    feats: Dict[str, float] = {}

    # bandごとにフィルタしたデータをキャッシュ
    bandpass_data: Dict[str, np.ndarray] = {}
    for band_name in ["theta", "alpha", "beta", "gamma"]:
        fmin, fmax = FREQ_BANDS[band_name]
        bp_data = mne.filter.filter_data(data.copy(), sfreq=sfreq, l_freq=fmin, h_freq=fmax, verbose="ERROR")
        bandpass_data[band_name] = bp_data

    from scipy.signal import coherence as coh_fun

    phases_for_conn = ["C", "D", "E", "F"]
    for band_name, bp_data in bandpass_data.items():
        fmin, fmax = FREQ_BANDS[band_name]
        for phase_key in phases_for_conn:
            tmin, tmax = PHASES[phase_key]
            mask = get_time_mask(times, tmin, tmax)
            if not np.any(mask):
                continue

            for pair_name, (pat1, pat2) in CONNECTIVITY_PAIRS.items():
                picks1 = pick_by_patterns(info, pat1)
                picks2 = pick_by_patterns(info, pat2)
                key_plv = f"PLV_{band_name}_{pair_name}_phase{phase_key}"
                key_coh = f"COH_{band_name}_{pair_name}_phase{phase_key}"

                if picks1.size == 0 or picks2.size == 0:
                    feats[key_plv] = np.nan
                    feats[key_coh] = np.nan
                    continue

                seg1 = bp_data[picks1][:, mask].mean(axis=0)
                seg2 = bp_data[picks2][:, mask].mean(axis=0)
                if seg1.size < 10 or seg2.size < 10:
                    feats[key_plv] = np.nan
                    feats[key_coh] = np.nan
                    continue

                # PLV
                analytic1 = hilbert(seg1)
                analytic2 = hilbert(seg2)
                phase_diff = np.angle(analytic1) - np.angle(analytic2)
                feats[key_plv] = float(np.abs(np.exp(1j * phase_diff).mean()))

                # Coherence
                try:
                    f, coh_vals = coh_fun(seg1, seg2, fs=sfreq, nperseg=min(256, seg1.size))
                    band_mask = (f >= fmin) & (f <= fmax)
                    feats[key_coh] = float(np.mean(coh_vals[band_mask])) if np.any(band_mask) else np.nan
                except Exception:
                    feats[key_coh] = np.nan

    return feats


def sample_entropy(signal: np.ndarray, m: int = 2, r: float | None = None) -> float:
    """Sample Entropy（簡易実装）。"""
    x = np.array(signal, dtype=float)
    N = x.size
    if N <= m + 1:
        return np.nan
    if r is None:
        r = 0.2 * np.std(x)
        if r == 0:
            return np.nan

    def _phi(m_: int) -> float:
        count = 0
        Xm = np.array([x[i:i + m_] for i in range(N - m_ + 1)])
        for i in range(Xm.shape[0] - 1):
            dist = np.max(np.abs(Xm[i+1:] - Xm[i]), axis=1)
            count += np.sum(dist <= r)
        return float(count)

    C_m = _phi(m)
    C_m1 = _phi(m + 1)
    if C_m == 0 or C_m1 == 0:
        return np.nan

    return float(-np.log(C_m1 / C_m))


def lz_complexity(signal: np.ndarray) -> float:
    """Lempel–Ziv Complexity（2値列での簡易正規化版）。"""
    x = np.array(signal, dtype=float)
    N = x.size
    if N < 10:
        return np.nan

    thr = np.median(x)
    s = ''.join('1' if v > thr else '0' for v in x)
    n = len(s)
    if n == 0:
        return np.nan

    i = 0; k = 1; l = 1; c = 1
    while True:
        if l + k > n:
            c += 1
            break
        if s[i:i + k] == s[l:l + k]:
            k += 1
            if l + k > n:
                c += 1
                break
        else:
            i += 1
            if i == l:
                c += 1
                l += k
                if l + 1 > n:
                    break
                i = 0; k = 1

    return float(c * np.log2(n) / n)


def compute_complexity_features(data: np.ndarray, times: np.ndarray, sfreq: float, info: mne.Info) -> Dict[str, float]:
    """複雑性：SampEn(Fz, D)＋LZC(all, F)。"""
    feats: Dict[str, float] = {}

    # SampEn: PhaseD, Fz
    tmin_D, tmax_D = PHASES["D"]
    mask_D = get_time_mask(times, tmin_D, tmax_D)
    picks_Fz = pick_by_patterns(info, ROI_PATTERNS["Fz"])
    if picks_Fz.size > 0 and np.any(mask_D):
        seg_Fz = data[picks_Fz][:, mask_D].mean(axis=0)
        feats["SampEn_Fz_phaseD"] = sample_entropy(seg_Fz, m=2, r=None) if seg_Fz.size > 50 else np.nan
    else:
        feats["SampEn_Fz_phaseD"] = np.nan

    # LZC: PhaseF, 全チャネル
    tmin_F, tmax_F = PHASES["F"]
    mask_F = get_time_mask(times, tmin_F, tmax_F)
    if np.any(mask_F):
        seg_all = data[:, mask_F].flatten(order="C")
        feats["LZC_all_phaseF"] = lz_complexity(seg_all) if seg_all.size > 50 else np.nan
    else:
        feats["LZC_all_phaseF"] = np.nan

    return feats


# ============================================================
# 5. MICROSTATES / PAC / EMBEDDING
# ============================================================

def _run_lengths(labels: np.ndarray, state: int) -> List[int]:
    """labels中でstateが連続する区間長（サンプル数）のリスト。"""
    labels = np.asarray(labels, dtype=int)
    mask = (labels == state).astype(int)
    if mask.sum() == 0:
        return []
    diffs = np.diff(mask)
    starts = np.where(diffs == 1)[0] + 1
    ends = np.where(diffs == -1)[0] + 1
    if mask[0] == 1:
        starts = np.r_[0, starts]
    if mask[-1] == 1:
        ends = np.r_[ends, len(mask)]
    return (ends - starts).tolist()


def compute_microstate_features(data_bc: np.ndarray, times: np.ndarray, sfreq: float, info: mne.Info, n_states: int = 4) -> Dict[str, float]:
    """簡易マイクロステート（C〜E, 0〜5s）。"""
    feats: Dict[str, float] = {}
    for s in range(n_states):
        s_id = s + 1
        feats[f"MS_occupancy_state{s_id}"] = np.nan
        feats[f"MS_mean_dur_state{s_id}"]  = np.nan
        feats[f"MS_mean_gfp_state{s_id}"]  = np.nan
    feats["MS_transition_rate"] = np.nan

    mask = get_time_mask(times, 0.0, 5.0)
    if not np.any(mask):
        return feats

    seg = data_bc[:, mask]
    n_ch, n_t = seg.shape
    if n_t < n_states * 5:
        return feats

    gfp = seg.std(axis=0)
    if np.allclose(gfp, 0):
        return feats

    thr = np.percentile(gfp, 70)
    high_idx = np.where(gfp >= thr)[0]
    if high_idx.size < n_states * 5:
        high_idx = np.arange(n_t)

    maps = seg[:, high_idx].T
    maps = maps - maps.mean(axis=1, keepdims=True)
    maps_norm = maps / (np.linalg.norm(maps, axis=1, keepdims=True) + 1e-12)

    try:
        centroids, _ = kmeans2(maps_norm, n_states, minit="points", iter=30)
    except Exception:
        return feats

    all_maps = seg.T
    all_maps = all_maps - all_maps.mean(axis=1, keepdims=True)
    all_maps_norm = all_maps / (np.linalg.norm(all_maps, axis=1, keepdims=True) + 1e-12)

    dists = np.sum((all_maps_norm[:, None, :] - centroids[None, :, :]) ** 2, axis=2)
    labels_all = np.argmin(dists, axis=1)

    duration_sec = n_t / sfreq
    feats["MS_transition_rate"] = float(np.sum(labels_all[1:] != labels_all[:-1]) / duration_sec)

    for s in range(n_states):
        s_id = s + 1
        mask_s = (labels_all == s)
        if not np.any(mask_s):
            continue
        feats[f"MS_occupancy_state{s_id}"] = float(mask_s.mean())
        runs = _run_lengths(labels_all, s)
        if len(runs) > 0:
            feats[f"MS_mean_dur_state{s_id}"] = float(np.mean(runs) / sfreq)
        feats[f"MS_mean_gfp_state{s_id}"] = float(gfp[mask_s].mean())

    return feats


def compute_pac_features(data: np.ndarray, times: np.ndarray, sfreq: float, info: mne.Info) -> Dict[str, float]:
    """Fzで θ位相×γ振幅PAC(Tort MI) を D/Eで算出。"""
    feats: Dict[str, float] = {}
    picks = pick_by_patterns(info, ROI_PATTERNS["Fz"])

    for phase_key in ["D", "E"]:
        name = f"PAC_MI_theta_gamma_Fz_phase{phase_key}"
        feats[name] = np.nan

        if picks.size == 0:
            continue

        tmin, tmax = PHASES[phase_key]
        mask = get_time_mask(times, tmin, tmax)
        if not np.any(mask):
            continue

        seg = data[picks][:, mask].mean(axis=0)
        if seg.size < int(sfreq * 0.5):
            continue

        low  = mne.filter.filter_data(seg[np.newaxis, :], sfreq, 4.0, 7.0,  verbose="ERROR")[0]
        high = mne.filter.filter_data(seg[np.newaxis, :], sfreq, 30.0, 80.0, verbose="ERROR")[0]

        phase = np.angle(hilbert(low))
        amp   = np.abs(hilbert(high))

        n_bins = 18
        bins = np.linspace(-np.pi, np.pi, n_bins + 1)
        amp_means = np.zeros(n_bins)
        for b in range(n_bins):
            idx = (phase >= bins[b]) & (phase < bins[b + 1])
            if np.any(idx):
                amp_means[b] = amp[idx].mean()

        if amp_means.sum() == 0:
            continue

        p = amp_means / amp_means.sum()
        H = -np.sum(p * np.log(p + 1e-12))
        Hmax = np.log(n_bins)
        feats[name] = float((Hmax - H) / Hmax)

    return feats


def compute_embedding_features(data: np.ndarray, times: np.ndarray, sfreq: float, info: mne.Info) -> Dict[str, float]:
    """PhaseD共分散のPCA寄与率と有効次元。"""
    feats: Dict[str, float] = {}
    mask = get_time_mask(times, *PHASES["D"])

    feats["PCA_var_ratio_pc1_phaseD"] = np.nan
    feats["PCA_var_ratio_pc2_phaseD"] = np.nan
    feats["PCA_var_ratio_pc3_phaseD"] = np.nan
    feats["PCA_var_ratio_cum3_phaseD"] = np.nan
    feats["PCA_effdim_phaseD"] = np.nan

    if not np.any(mask):
        return feats

    seg = data[:, mask].T  # (time, ch)
    if seg.shape[0] <= 3:
        return feats

    seg_centered = seg - seg.mean(axis=0, keepdims=True)
    C = np.cov(seg_centered, rowvar=False)

    evals, _ = np.linalg.eigh(C)
    evals = np.maximum(evals, 0.0)
    evals = evals[np.argsort(evals)[::-1]]

    total = evals.sum()
    if total <= 0:
        return feats

    var_ratio = evals / total
    if len(var_ratio) > 0:
        feats["PCA_var_ratio_pc1_phaseD"] = float(var_ratio[0])
    if len(var_ratio) > 1:
        feats["PCA_var_ratio_pc2_phaseD"] = float(var_ratio[1])
    if len(var_ratio) > 2:
        feats["PCA_var_ratio_pc3_phaseD"] = float(var_ratio[2])
        feats["PCA_var_ratio_cum3_phaseD"] = float(var_ratio[:3].sum())

    denom = float((evals**2).sum())
    feats["PCA_effdim_phaseD"] = float((total**2) / denom) if denom > 0 else np.nan
    return feats


# ============================================================
# 6. TRIAL-LEVEL EXTRACTION
# ============================================================

@dataclass
class TrialMeta:
    subject: str
    participant: Optional[str]
    run: str
    trial_in_run: int
    number: Optional[int]        # 音番号（48音のID）
    FileName: Optional[str]      # wav名（HCU/主観評価とJOINするキー）
    category: str
    qc_pass: bool


def build_trial_fif_path(meta: TrialMeta) -> Path:
    """
    derivatives/epochs_trial/<subject>/fif/<subject>_<run>_trialXX-epo.fif
    """
    trial_str = f"trial{meta.trial_in_run:02d}"
    fif_dir = TRIAL_BASE_DIR / meta.subject / "fif"
    fname = f"{meta.subject}_{meta.run}_{trial_str}-epo.fif"
    return fif_dir / fname


def extract_features_for_trial(meta: TrialMeta) -> Dict[str, Any]:
    """1 trial (1 FIF) から全特徴量を抽出。"""
    fif_path = build_trial_fif_path(meta)
    epochs = mne.read_epochs(fif_path, preload=True, verbose="ERROR")
    if len(epochs) != 1:
        raise ValueError(f"Expected 1 epoch in {fif_path}, got {len(epochs)}")

    data  = epochs.get_data()[0]  # (ch, time)
    times = epochs.times
    sfreq = epochs.info["sfreq"]
    info  = epochs.info

    # ERP向けベースライン補正
    data_bc = baseline_correct(data, times, tmin=-0.2, tmax=0.0)

    feats: Dict[str, Any] = {}

    # --- 統合用メタ（ここが肝）---
    feats["subject"]      = meta.subject
    feats["participant"]  = meta.participant
    feats["run"]          = meta.run
    feats["trial_in_run"] = meta.trial_in_run
    feats["number"]       = meta.number
    feats["FileName"]     = meta.FileName
    feats["category"]     = meta.category
    feats["qc_pass"]      = bool(meta.qc_pass)

    # === ERP ===
    feats["ERP_N1_FC_80_130ms"]   = erp_peak(data_bc, times, info, "FC", 0.080, 0.130, "min")
    feats["ERP_P2_FC_150_250ms"]  = erp_peak(data_bc, times, info, "FC", 0.150, 0.250, "max")
    feats["ERP_N2_FC_250_350ms"]  = erp_peak(data_bc, times, info, "FC", 0.250, 0.350, "min")
    feats["ERP_P3_P_300_500ms"]   = erp_peak(data_bc, times, info, "P",  0.300, 0.500, "max")
    feats["ERP_LPP_P_400_800ms"]  = erp_peak(data_bc, times, info, "P",  0.400, 0.800, "mean")

    # オフセット（音終了=5s想定）
    feats["ERP_P2off_FC_0_300ms_postOff"]  = erp_peak(data_bc, times, info, "FC", 5.0, 5.3, "max")
    feats["ERP_LPPoff_P_300_800ms_postOff"] = erp_peak(data_bc, times, info, "P",  5.3, 5.8, "mean")

    # === SPN / 持続 ===
    feats["SPN_Cz_pre"] = sustained_potential(data_bc, times, info, "Cz", *PHASES["B"])
    feats["SPN_Pz_pre"] = sustained_potential(data_bc, times, info, "Pz", *PHASES["B"])

    feats["SP_D_Cz"] = sustained_potential(data_bc, times, info, "Cz", *PHASES["D"])
    feats["SP_D_Pz"] = sustained_potential(data_bc, times, info, "Pz", *PHASES["D"])
    feats["SP_E_Cz"] = sustained_potential(data_bc, times, info, "Cz", *PHASES["E"])
    feats["SP_E_Pz"] = sustained_potential(data_bc, times, info, "Pz", *PHASES["E"])
    feats["SP_F_Cz"] = sustained_potential(data_bc, times, info, "Cz", *PHASES["F"])
    feats["SP_F_Pz"] = sustained_potential(data_bc, times, info, "Pz", *PHASES["F"])

    # === バンドパワー/ERD ===
    bp_baseline: Dict[tuple, float] = {}
    baseline_roi_band = [("Fz", "theta"), ("Pz", "alpha"), ("OCC", "alpha"), ("Cz", "beta")]
    for roi_key, band_name in baseline_roi_band:
        bp_baseline[(roi_key, band_name)] = bandpower_roi(data, times, sfreq, info, roi_key, band_name, *PHASES["A"])

    for phase_key in ["C", "D", "E", "F"]:
        tmin, tmax = PHASES[phase_key]

        bp = bandpower_roi(data, times, sfreq, info, "Fz", "theta", tmin, tmax)
        feats[f"BP_theta_Fz_phase{phase_key}"]  = bp
        feats[f"ERD_theta_Fz_phase{phase_key}"] = erd_ers_db(bp, bp_baseline[("Fz", "theta")])

        bp = bandpower_roi(data, times, sfreq, info, "Pz", "alpha", tmin, tmax)
        feats[f"BP_alpha_Pz_phase{phase_key}"]  = bp
        feats[f"ERD_alpha_Pz_phase{phase_key}"] = erd_ers_db(bp, bp_baseline[("Pz", "alpha")])

        bp = bandpower_roi(data, times, sfreq, info, "OCC", "alpha", tmin, tmax)
        feats[f"BP_alpha_Occ_phase{phase_key}"]  = bp
        feats[f"ERD_alpha_Occ_phase{phase_key}"] = erd_ers_db(bp, bp_baseline[("OCC", "alpha")])

        bp = bandpower_roi(data, times, sfreq, info, "Cz", "beta", tmin, tmax)
        feats[f"BP_beta_Cz_phase{phase_key}"]  = bp
        feats[f"ERD_beta_Cz_phase{phase_key}"] = erd_ers_db(bp, bp_baseline[("Cz", "beta")])

    # FAA(B〜F)
    for phase_key in ["B", "C", "D", "E", "F"]:
        feats[f"FAA_alpha_phase{phase_key}"] = faa_alpha(data, times, sfreq, info, *PHASES[phase_key])

    # 高次特徴
    feats.update(compute_riemann_features(data, times, sfreq, info))
    feats.update(compute_connectivity_features(data, times, sfreq, info))
    feats.update(compute_complexity_features(data, times, sfreq, info))
    feats.update(compute_microstate_features(data_bc, times, sfreq, info))
    feats.update(compute_pac_features(data, times, sfreq, info))
    feats.update(compute_embedding_features(data, times, sfreq, info))

    return feats


# ============================================================
# 7. INDEX LOADING（統合用に “number/FileName” を必ず拾う）
# ============================================================

def load_epoch_index() -> pd.DataFrame:
    """
    master_epoch_index.csv を読み込み、
    subject/run/trial_in_run/qc_pass/category/number/FileName を可能な限り揃える。
    """
    if not MASTER_INDEX_CSV.exists():
        raise FileNotFoundError(f"NOT FOUND: {MASTER_INDEX_CSV}")

    df = pd.read_csv(MASTER_INDEX_CSV)
    cols = list(df.columns)

    def find_col(candidates: List[str]) -> Optional[str]:
        for c in candidates:
            if c in cols:
                return c
        return None

    # --- 必須キー ---
    subj_col = find_col(["subject", "subj", "subj_id", "participant", "被験者", "Subject"])
    run_col  = find_col(["run", "Run", "run_id", "session", "RunNo", "run_no"])
    tri_col  = find_col(["trial_in_run", "trial", "Trial", "trial_index", "trial_idx", "trial_no", "trial_number"])

    if subj_col is None or run_col is None or tri_col is None:
        raise ValueError(
            "master_epoch_index.csv に必要列が足りません。\n"
            f"  subject候補={subj_col}, run候補={run_col}, trial候補={tri_col}\n"
            f"  columns={cols}"
        )

    # --- できれば欲しい（統合の要）---
    num_col = find_col(["number", "No", "sound_no", "sound_number", "sound_id"])
    fn_col  = find_col(["FileName", "filename", "file_name", "wav", "sound_name"])

    cat_col = find_col(["category", "カテゴリ", "カテゴリー", "valence_group", "label"])
    qc_col  = find_col(["qc_pass", "qc_ok", "use", "valid", "qc_amp_pass"])

    # 正規化rename
    ren = {subj_col: "subject", run_col: "run", tri_col: "trial_in_run"}
    if num_col is not None: ren[num_col] = "number"
    if fn_col  is not None: ren[fn_col]  = "FileName"
    if cat_col is not None: ren[cat_col] = "category"
    if qc_col  is not None: ren[qc_col]  = "qc_pass"
    df = df.rename(columns=ren)

    # 型・正規化
    df["subject"] = df["subject"].astype(str)
    df["run"] = df["run"].apply(normalize_run)
    df["trial_in_run"] = pd.to_numeric(df["trial_in_run"], errors="coerce").astype("Int64")

    if "qc_pass" not in df.columns:
        df["qc_pass"] = True
    df["qc_pass"] = df["qc_pass"].fillna(False).astype(bool)

    if "category" not in df.columns:
        df["category"] = "unknown"
    df["category"] = df["category"].astype(str)

    if "number" in df.columns:
        df["number"] = pd.to_numeric(df["number"], errors="coerce").astype("Int64")
    else:
        df["number"] = pd.Series([pd.NA] * len(df), dtype="Int64")

    # --- FileName 正規化（NAを壊さない）---
    if "FileName" in df.columns:
        df["FileName_raw"] = df["FileName"].astype("string")
        df["FileName"] = df["FileName_raw"].apply(canonicalize_filename).astype("string")

        changed = df["FileName_raw"].notna() & (df["FileName_raw"] != df["FileName"])
        if changed.any():
            df.loc[changed, ["subject", "run", "trial_in_run", "number", "FileName_raw", "FileName"]].head(500).to_csv(
                AUDIT_DIR / "master_epoch_index_filename_normalized_head500.csv",
                index=False, encoding="utf-8-sig"
            )
            print(f"[INFO] FileName正規化あり -> {AUDIT_DIR/'master_epoch_index_filename_normalized_head500.csv'}")
    else:
        df["FileName_raw"] = pd.Series([pd.NA] * len(df), dtype="string")
        df["FileName"] = pd.Series([pd.NA] * len(df), dtype="string")



    # participant付与（主観評価Pxxと統合しやすく）
    df["participant"] = df["subject"].apply(subject_to_participant)

    # キー重複チェック（統合の安定性）
    key = ["subject", "run", "trial_in_run"]
    dup = df.duplicated(subset=key, keep=False).sum()
    if dup > 0:
        ex = df[df.duplicated(subset=key, keep=False)].sort_values(key).head(40)
        ex.to_csv(AUDIT_DIR / "master_epoch_index_duplicate_keys.csv", index=False, encoding="utf-8-sig")
        raise ValueError(
            f"master_epoch_index のキーが重複しています: dup_rows={dup}\n"
            f"監査ログ: {AUDIT_DIR/'master_epoch_index_duplicate_keys.csv'}"
        )

    # number/FileName 欠損監査（ある程度は許容するが、sound集約は危険）
    miss_n = int(df["number"].isna().sum())
    miss_f = int(df["FileName"].isna().sum())
    if miss_n > 0 or miss_f > 0:
        df[df["number"].isna() | df["FileName"].isna()][key + ["number", "FileName"]].head(300).to_csv(
            AUDIT_DIR / "master_epoch_index_missing_sound_identity_head300.csv",
            index=False, encoding="utf-8-sig"
        )
        print(f"[WARN] master_epoch_index に number/FileName 欠損があります: number_missing={miss_n}, FileName_missing={miss_f}")
        print(f"       audit -> {AUDIT_DIR/'master_epoch_index_missing_sound_identity_head300.csv'}")

    # --- number -> FileName が 1対1 になっているか（JOINの健全性チェック）---
    sub = df.dropna(subset=["number", "FileName"]).copy()

    if sub.empty:
        print("[WARN] number/FileName が全て欠損のため、1対1チェックをスキップします。")
    else:
        nu = sub.groupby("number")["FileName"].nunique(dropna=True)
        bad = nu[nu > 1]

    if not bad.empty:
        detail = sub[sub["number"].isin(bad.index)][
            ["number", "FileName", "subject", "run", "trial_in_run"]
        ].sort_values(["number", "FileName", "subject", "run", "trial_in_run"])

        out_path = AUDIT_DIR / "master_epoch_index_number_to_multiple_filenames.csv"
        detail.to_csv(out_path, index=False, encoding="utf-8-sig")

        raise ValueError(
            f"number→FileName が 1対1 になっていません（{bad.size} numbers）\n"
            f"audit -> {out_path}"
        )
    else:
        print("[OK] number→FileName は 1対1 です（正規化後）。")


    return df


def build_trial_meta_list(df_index: pd.DataFrame) -> List[TrialMeta]:
    """master index DataFrame から TrialMeta のリストを生成。"""
    metas: List[TrialMeta] = []
    for _, row in df_index.iterrows():
        metas.append(
            TrialMeta(
                subject      = str(row["subject"]),
                participant  = row.get("participant", None),
                run          = str(row["run"]),
                trial_in_run = int(row["trial_in_run"]),
                number       = (None if pd.isna(row["number"]) else int(row["number"])),
                FileName     = (None if pd.isna(row["FileName"]) else str(row["FileName"])),
                category     = str(row["category"]),
                qc_pass      = bool(row["qc_pass"]),
            )
        )
    return metas


# ============================================================
# 8. TABLE BUILDING（trial → sound → subject）
# ============================================================

def build_trial_table() -> pd.DataFrame:
    """
    全trialについて特徴量を計算して trial_table を返す。
    ★エラーtrialも残す（feature_ok=False）ことで、後段JOINが安定する。
    """
    df_index = load_epoch_index()
    metas = build_trial_meta_list(df_index)

    total = len(metas)
    n_target = sum(1 for m in metas if (m.qc_pass or not PROCESS_ONLY_QC_PASS))
    print(f"Total trials in master_epoch_index: {total}")
    print(f"Trials to process (qc filter={PROCESS_ONLY_QC_PASS}): {n_target}")

    rows: List[Dict[str, Any]] = []

    import sys
    done = 0
    for meta in metas:
        # A運用：qc_fail は“行は残すが計算はしない”
        if PROCESS_ONLY_QC_PASS and (not meta.qc_pass):
            rows.append(
                {
                    "subject": meta.subject,
                    "participant": meta.participant,
                    "run": meta.run,
                    "trial_in_run": meta.trial_in_run,
                    "number": meta.number,
                    "FileName": meta.FileName,
                    "category": meta.category,
                    "qc_pass": meta.qc_pass,
                    "feature_ok": False,
                    "error_msg": "qc_fail_skipped",
                }
            )
            continue

        done += 1
        try:
            feats = extract_features_for_trial(meta)
            feats["feature_ok"] = True
            feats["error_msg"] = ""
        except Exception as e:
            if not KEEP_ERROR_ROWS:
                sys.stdout.write("\n")
                print(f"[ERROR {done}/{n_target}] {meta.subject}-{meta.run}-trial{meta.trial_in_run:02d}: {e}")
                continue

            feats = {
                "subject": meta.subject,
                "participant": meta.participant,
                "run": meta.run,
                "trial_in_run": meta.trial_in_run,
                "number": meta.number,
                "FileName": meta.FileName,
                "category": meta.category,
                "qc_pass": meta.qc_pass,
                "feature_ok": False,
                "error_msg": str(e),
            }

        rows.append(feats)

        sys.stdout.write(
            f"\rProcessing trials: {done}/{n_target} ({meta.subject}-{meta.run}-trial{meta.trial_in_run:02d})"
        )
        sys.stdout.flush()


    print()
    trial_table = pd.DataFrame(rows)
    print("Finished trial_table. Shape:", trial_table.shape)

    # エラー監査
    n_bad = int((trial_table["feature_ok"] == False).sum()) if "feature_ok" in trial_table.columns else 0
    if n_bad > 0:
        trial_table[trial_table["feature_ok"] == False].head(300).to_csv(
            AUDIT_DIR / "eeg_feature_errors_head300.csv", index=False, encoding="utf-8-sig"
        )
        print(f"[WARN] feature_ok=False が {n_bad} 件あります。audit -> {AUDIT_DIR/'eeg_feature_errors_head300.csv'}")

    return trial_table


def aggregate_sound_table(trial_table: pd.DataFrame) -> Optional[pd.DataFrame]:
    """
    sound(number×FileName)単位で平均＋SEを集約。
    ★number/FileName が揃っていないと “音” 集約にならないので、その場合はNoneを返す。
    """ 
    trial_table = trial_table[trial_table["feature_ok"] == True].copy()
    if ("number" not in trial_table.columns) or ("FileName" not in trial_table.columns):
        print("[WARN] trial_table に number/FileName が無いので sound-level 集約をスキップします。")
        return None

    # 欠損が多い場合も危険（音の同定が曖昧）
    miss = trial_table["number"].isna().sum() + trial_table["FileName"].isna().sum()
    if miss > 0:
        print("[WARN] number/FileName 欠損行があるため sound-level は不完全になる可能性があります。")

    # 数値列のうち、メタ（ID系）を除外して特徴量だけにする
    meta_numeric = {"trial_in_run", "number"}  # ← numberも除外
    numeric_cols = [c for c in trial_table.select_dtypes(include=[np.number]).columns if c not in meta_numeric]

    group_keys = ["number", "FileName"]
    g = trial_table.groupby(group_keys)[numeric_cols]

    mean_df = g.mean().add_suffix("_mean")
    se_df   = (g.std() / np.sqrt(g.count())).add_suffix("_se")

    out = pd.concat([mean_df, se_df], axis=1).reset_index()
    return out


def aggregate_subject_table(trial_table: pd.DataFrame) -> pd.DataFrame:
    """
    participant(Pxx)単位で平均。
    participant が作れない場合は subject 単位にフォールバック。
    """
    trial_table = trial_table[trial_table["feature_ok"] == True].copy()
    id_col = "participant" if "participant" in trial_table.columns and trial_table["participant"].notna().any() else "subject"

    numeric_cols = trial_table.select_dtypes(include=[np.number]).columns
    # ID数値(trial_in_runなど)は外す
    drop_numeric = {"trial_in_run"}
    numeric_cols = [c for c in numeric_cols if c not in drop_numeric]

    g = trial_table.groupby(id_col)[numeric_cols]
    out = g.mean().add_suffix("_mean").reset_index().rename(columns={id_col: "participant" if id_col=="participant" else "subject"})
    return out


# ============================================================
# 9. MAIN
# ============================================================

def main():
    print("ROOT_DIR:", ROOT_DIR)
    print("MASTER_INDEX_CSV:", MASTER_INDEX_CSV)
    print("TRIAL_BASE_DIR:", TRIAL_BASE_DIR)

    # 1) trial
    print("\nBuilding trial_table ...")
    trial_table = build_trial_table()
    OUT_TRIAL_CSV.parent.mkdir(parents=True, exist_ok=True)
    trial_table.to_csv(OUT_TRIAL_CSV, index=False, encoding="utf-8-sig")
    print(f"Saved trial-level features -> {OUT_TRIAL_CSV}")

    # 2) sound（可能な場合のみ）
    print("\nAggregating sound_table ...")
    sound_table = aggregate_sound_table(trial_table)
    if sound_table is not None:
        sound_table.to_csv(OUT_SOUND_CSV, index=False, encoding="utf-8-sig")
        print(f"Saved sound-level features -> {OUT_SOUND_CSV}")
        print("  n_sounds:", len(sound_table))
    else:
        print("[INFO] sound-level features は未出力（number/FileName不足のため）")

    # 3) subject/participant
    print("\nAggregating subject_table ...")
    subj_table = aggregate_subject_table(trial_table)
    subj_table.to_csv(OUT_SUBJ_CSV, index=False, encoding="utf-8-sig")
    print(f"Saved subject-level features -> {OUT_SUBJ_CSV}")
    print("  n_subjects:", len(subj_table))

    print("\n=== DONE ===")
    print("trial :", OUT_TRIAL_CSV)
    print("sound :", OUT_SOUND_CSV if sound_table is not None else "(skipped)")
    print("subj  :", OUT_SUBJ_CSV)
    print("audit :", AUDIT_DIR)


if __name__ == "__main__":
    main()


ROOT_DIR: /Users/shunsuke/EEG_48sounds
MASTER_INDEX_CSV: /Users/shunsuke/EEG_48sounds/derivatives/master_epoch_index.csv
TRIAL_BASE_DIR: /Users/shunsuke/EEG_48sounds/derivatives/epochs_trial

Building trial_table ...
[INFO] FileName正規化あり -> /Users/shunsuke/EEG_48sounds/output/integration_audit/master_epoch_index_filename_normalized_head500.csv
[OK] number→FileName は 1対1 です（正規化後）。
Total trials in master_epoch_index: 1728
Trials to process (qc filter=True): 1588
Processing trials: 8/1588 (01_高見-run1-trial08)

  return fun(*args, **kwargs)


Processing trials: 24/1588 (01_高見-run1-trial24)

  return fun(*args, **kwargs)


Processing trials: 52/1588 (01_高見-run2-trial04)

  return fun(*args, **kwargs)


Processing trials: 61/1588 (01_高見-run2-trial13)

  return fun(*args, **kwargs)


Processing trials: 64/1588 (01_高見-run2-trial16)

  return fun(*args, **kwargs)


Processing trials: 69/1588 (01_高見-run2-trial21)

  return fun(*args, **kwargs)


Processing trials: 74/1588 (01_高見-run2-trial26)

  return fun(*args, **kwargs)


Processing trials: 85/1588 (01_高見-run2-trial37)

  return fun(*args, **kwargs)


Processing trials: 87/1588 (01_高見-run2-trial39)

  return fun(*args, **kwargs)


Processing trials: 92/1588 (01_高見-run2-trial44)

  return fun(*args, **kwargs)


Processing trials: 96/1588 (01_高見-run2-trial48)

  return fun(*args, **kwargs)


Processing trials: 97/1588 (01_高見-run3-trial01)

  return fun(*args, **kwargs)


Processing trials: 99/1588 (01_高見-run3-trial03)

  return fun(*args, **kwargs)


Processing trials: 109/1588 (01_高見-run3-trial13)

  return fun(*args, **kwargs)


Processing trials: 117/1588 (01_高見-run3-trial21)

  return fun(*args, **kwargs)


Processing trials: 127/1588 (01_高見-run3-trial31)

  return fun(*args, **kwargs)


Processing trials: 300/1588 (03_江口-run1-trial29)

  return fun(*args, **kwargs)


Processing trials: 301/1588 (03_江口-run1-trial30)

  return fun(*args, **kwargs)


Processing trials: 322/1588 (03_江口-run2-trial05)

  return fun(*args, **kwargs)


Processing trials: 340/1588 (03_江口-run2-trial24)

  return fun(*args, **kwargs)


Processing trials: 346/1588 (03_江口-run2-trial30)

  return fun(*args, **kwargs)


Processing trials: 358/1588 (03_江口-run2-trial42)

  return fun(*args, **kwargs)


Processing trials: 364/1588 (03_江口-run2-trial48)

  return fun(*args, **kwargs)


Processing trials: 380/1588 (03_江口-run3-trial19)

  return fun(*args, **kwargs)


Processing trials: 382/1588 (03_江口-run3-trial21)

  return fun(*args, **kwargs)


Processing trials: 390/1588 (03_江口-run3-trial29)

  return fun(*args, **kwargs)


Processing trials: 393/1588 (03_江口-run3-trial32)

  return fun(*args, **kwargs)


Processing trials: 1294/1588 (10_細川-run3-trial19)

  return fun(*args, **kwargs)


Processing trials: 1588/1588 (12_濱津-run3-trial48)
Finished trial_table. Shape: (1728, 1194)
[WARN] feature_ok=False が 140 件あります。audit -> /Users/shunsuke/EEG_48sounds/output/integration_audit/eeg_feature_errors_head300.csv
Saved trial-level features -> /Users/shunsuke/EEG_48sounds/derivatives/eeg_features_trial.csv

Aggregating sound_table ...
Saved sound-level features -> /Users/shunsuke/EEG_48sounds/derivatives/eeg_features_sound.csv
  n_sounds: 48

Aggregating subject_table ...
Saved subject-level features -> /Users/shunsuke/EEG_48sounds/derivatives/eeg_features_subject.csv
  n_subjects: 12

=== DONE ===
trial : /Users/shunsuke/EEG_48sounds/derivatives/eeg_features_trial.csv
sound : /Users/shunsuke/EEG_48sounds/derivatives/eeg_features_sound.csv
subj  : /Users/shunsuke/EEG_48sounds/derivatives/eeg_features_subject.csv
audit : /Users/shunsuke/EEG_48sounds/output/integration_audit
