#!/usr/bin/env python3
"""
erp_viz.py

Unified ERP Visualization

This script computes event-related potentials (ERPs) from time-domain EEG CSV exports,
for both simulated and real-world driving hazard detection tasks. It averages across
expert and novice drivers, applies a custom head montage, and produces publication-quality
ERP line plots with highlighted temporal windows.

Usage:
    python erp_viz.py \
        --set-ref-dir /path/to/eeglab/sets \
        --set-ref-file subject01.set \
        --csv-dir /path/to/csv_subfolders \
        --output-dir /path/to/save/figures \
        --envs v r           \
        --n-expert 36        \
        --n-novice 64        \
        --tmin -200          \
        --tmax 600           \
        --sfreq 1000
"""

import os
import argparse

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import mne
from mne.filter import filter_data


def parse_args():
    """
    Build the command-line parser and return the parsed namespace.

    The four path-like options (--set-ref-dir, --set-ref-file, --csv-dir,
    --output-dir) are required; every other option falls back to a default.
    Arguments are read from ``sys.argv``.
    """
    parser = argparse.ArgumentParser(
        description="ERP visualization for v/r driving data"
    )

    # Required string options, registered in the order they appear in --help.
    required_options = (
        ("--set-ref-dir", "Directory of EEGLAB .set files"),
        ("--set-ref-file", "EEGLAB .set filename for channel montage"),
        ("--csv-dir", "Directory containing subject subfolders of CSV data"),
        ("--output-dir", "Where to write ERP figures"),
    )
    for flag, help_text in required_options:
        parser.add_argument(flag, required=True, help=help_text)

    parser.add_argument(
        "--envs",
        nargs="+",
        choices=["v", "r"],
        default=["v"],
        help="Environment codes: 'v' (virtual), 'r' (real)",
    )

    # Integer-valued options: group sizes and the epoch window in milliseconds.
    integer_options = (
        ("--n-expert", 36, "Number of expert subjects"),
        ("--n-novice", 64, "Number of novice subjects"),
        ("--tmin", -200, "Epoch start time (ms)"),
        ("--tmax", 600, "Epoch end time (ms)"),
    )
    for flag, default, help_text in integer_options:
        parser.add_argument(flag, type=int, default=default, help=help_text)

    parser.add_argument(
        "--sfreq",
        type=float,
        default=1000.0,
        help="Sampling frequency of CSV time series (Hz)",
    )
    parser.add_argument(
        "--drop-channels",
        nargs="+",
        default=['HEO', 'VEO', 'CB1', 'CB2', 'EKG', 'EMG', 'TRIGGER', 'M1', 'M2'],
        help="Auxiliary channels to exclude",
    )
    return parser.parse_args()


def rename_channel(raw: str) -> str:
    """
    Normalize a channel label: strip surrounding whitespace, keep the first
    character upper-case and lower-case everything after it.
    e.g. 'FC1' → 'Fc1'; an empty or whitespace-only label yields ''.
    """
    label = raw.strip().upper()
    if not label:
        return label
    head, tail = label[0], label[1:]
    return head + tail.lower()


def is_aux(ch: str, drop_list: list) -> bool:
    """Return True when `ch` matches any entry of `drop_list`, ignoring case."""
    wanted = ch.upper()
    return any(wanted == entry.upper() for entry in drop_list)


def extract_montage(set_dir: str, set_file: str, drop_list: list):
    """
    Read the EEGLAB .set file at `set_dir`/`set_file`, normalize its channel
    labels with `rename_channel`, discard channels flagged by `is_aux`, and
    build an MNE dig montage (coord_frame='head') from the first three
    entries of each remaining channel's `loc` vector.
    Returns (montage, channel_names).
    """
    set_path = os.path.join(set_dir, set_file)
    recording = mne.io.read_raw_eeglab(set_path, preload=True, verbose=False)

    # Apply the same label normalization that the CSV loader uses
    recording.rename_channels(
        {label: rename_channel(label) for label in recording.ch_names}
    )

    # Keep only the non-auxiliary channels
    retained = [label for label in recording.ch_names
                if not is_aux(label, drop_list)]
    recording.pick(retained)

    # Pair each channel name with its info record to collect positions
    positions = {}
    for label, ch_info in zip(recording.ch_names, recording.info['chs']):
        positions[label] = ch_info['loc'][:3]

    montage = mne.channels.make_dig_montage(ch_pos=positions,
                                            coord_frame='head')
    return montage, recording.ch_names


def load_single_csv(path: str, tmin: int, tmax: int,
                    sfreq: float, drop_list: list):
    """
    Load one time‑series CSV:
      - first row: times (s), first col header
      - first col: channel names, first row header

    Fully empty rows and columns are discarded, the time axis is converted
    to milliseconds and restricted to [tmin, tmax] (inclusive), channel
    labels are normalized with `rename_channel`, and channels flagged by
    `is_aux` are removed.

    Returns (data_filtered [n_chan×n_time], chan_names, times_ms), or
    (None, None, None) when the file has fewer than two rows/columns, no
    sample falls inside the window, or no non‑auxiliary channel remains.
    """
    # DataFrame.dropna accepts a single axis per call (a list of axes raises
    # TypeError on pandas >= 1.0), so drop empty rows first, then columns.
    df = pd.read_csv(path, header=None)
    df = df.dropna(how='all', axis=0).dropna(how='all', axis=1)
    if df.shape[0] < 2 or df.shape[1] < 2:
        return None, None, None

    # Extract and select time window
    times = df.iloc[0, 1:].astype(float).values * 1000  # ms
    idx = np.flatnonzero((times >= tmin) & (times <= tmax))
    if idx.size == 0:
        return None, None, None
    sel_times = times[idx]

    # Extract channel data, skipping auxiliary rows
    raw_vals = df.iloc[1:, 1:].values.astype(float)
    raw_names = df.iloc[1:, 0].values
    keep_rows, names = [], []
    for i, ch in enumerate(raw_names):
        std = rename_channel(ch)
        if not is_aux(std, drop_list):
            keep_rows.append(i)
            names.append(std)
    if not keep_rows:
        # Every row was auxiliary: nothing to filter or return.
        return None, None, None

    mat = raw_vals[keep_rows][:, idx]
    # NOTE(review): l_freq and h_freq are both None, so no cut-off frequency
    # is requested here — confirm whether a band-pass was intended.
    filt = filter_data(mat, sfreq, l_freq=None, h_freq=None,
                       method='fir', fir_design='firwin', verbose=False)
    return filt, names, sel_times


def load_epochs(folder: str, args):
    """
    Load every '.csv' file (case-insensitive extension, sorted by name) in
    `folder` through `load_single_csv` and stack the results into one array
    [n_epochs×n_chan×n_time].

    The first successfully loaded file defines the reference channel list
    and time axis; later files whose channel list or number of time points
    differs are skipped. Returns (epochs, chan_names, times_ms), or
    (None, None, None) when no file could be used.
    """
    csv_names = sorted(
        name for name in os.listdir(folder) if name.lower().endswith('.csv')
    )

    epochs, ref_chs, ref_times = [], None, None
    for name in csv_names:
        mat, chs, ts = load_single_csv(os.path.join(folder, name),
                                       args.tmin, args.tmax,
                                       args.sfreq, args.drop_channels)
        if mat is None:
            continue
        if ref_chs is None:
            # First usable file fixes the expected layout
            ref_chs, ref_times = chs, ts
        elif not (chs == ref_chs and len(ts) == len(ref_times)):
            continue
        epochs.append(mat)

    if not epochs:
        return None, None, None
    return np.stack(epochs, axis=0), ref_chs, ref_times


def load_subjects(base: str, driver: str, n_subj: int,
                  cond: str, args):
    """
    Collect per-subject average waveforms for one driver group and condition.

    For subject ids 1..n_subj the folder `<driver>_<id:02d>_<cond>` under
    `base` is read with `load_epochs`; missing folders and folders without
    usable epochs are skipped. Each subject's epochs are averaged with
    `np.nanmean`. The first usable subject defines the reference channel
    list and time axis; subjects whose layout differs are skipped.

    Returns ([n_subj×n_chan×n_time], chan_names, times_ms), or
    (None, None, None) when no subject contributed data.
    """
    subject_means, ref_chs, ref_times = [], None, None
    for sid in range(1, n_subj + 1):
        subj_dir = os.path.join(base, f"{driver}_{sid:02d}_{cond}")
        if not os.path.isdir(subj_dir):
            continue
        epochs, chs, ts = load_epochs(subj_dir, args)
        if epochs is None:
            continue
        subj_mean = np.nanmean(epochs, axis=0)
        if ref_chs is None:
            ref_chs, ref_times = chs, ts
        elif chs != ref_chs or len(ts) != len(ref_times):
            continue
        subject_means.append(subj_mean)

    if not subject_means:
        return None, None, None

    # Trim every subject to the shortest time axis before stacking
    shortest = min(mean.shape[1] for mean in subject_means)
    trimmed = [mean[:, :shortest] for mean in subject_means]
    return np.stack(trimmed, axis=0), ref_chs, ref_times[:shortest]


def make_evoked(data3d, chs, times, sfreq, montage):
    """
    Average `data3d` over its first axis (NaN-aware) and wrap the result in
    an MNE EvokedArray whose channels are all typed 'eeg'.

    The evoked start time is `times[0]` divided by 1000, and `montage` is
    attached with on_missing='ignore'. Returns the EvokedArray.
    """
    grand_mean = np.nanmean(data3d, axis=0)
    evoked = mne.EvokedArray(
        grand_mean,
        mne.create_info(chs, sfreq=sfreq, ch_types='eeg'),
        tmin=times[0] / 1000.0,
    )
    evoked.set_montage(montage, on_missing='ignore')
    return evoked


def _channel_index(ch_names, chan_up):
    """Return the index of the first name equal to `chan_up` ignoring case, else None."""
    for i, name in enumerate(ch_names):
        if name.upper() == chan_up:
            return i
    return None


def run_visualization(args):
    """
    Main workflow: montage, data loading, ERP computation, plotting.

    For every environment in `args.envs`, one Evoked object is built per
    driver group × condition from the CSV folders under `args.csv_dir`
    (combinations without data are skipped). For each channel of interest
    that has highlight windows defined, the available conditions are drawn
    on one figure, which is written to `args.output_dir` as
    ``ERP_<ENV>_<CHANNEL>.png`` and ``.svg``. Channels that no Evoked
    contains produce no figure.
    """
    os.makedirs(args.output_dir, exist_ok=True)

    # Map driver & condition codes to readable labels
    drivers = {'e': 'Expert', 'n': 'Novice'}
    conditions = {'c': 'Control', 'h': 'Overt Hazard', 'o': 'Covert Hazard'}

    # Color scheme for each label
    colors = {
        "Expert Control":       "#1F77B4",
        "Expert Overt Hazard":  "#6699CC",
        "Expert Covert Hazard": "#99BBDD",
        "Novice Control":       "#FF7F0E",
        "Novice Overt Hazard":  "#FFAA66",
        "Novice Covert Hazard": "#FFCCAA",
    }

    # Temporal windows to highlight (ms), keyed by upper-case channel name
    # Example annotations: P2 (200–300 ms), P3a (300–400 ms), N1 (130–150 ms), etc.
    intervals_map = {
        'v': {  # virtual driving data
            'O1':  [
                (200, 300),  # P2 component: 200–300 ms
                (300, 400)   # P3a component: 300–400 ms
            ],
            'FC1': [
                (130, 150)   # N1 component: 130–150 ms
            ],
            'FT8': [
                (130, 150)   # N1 component: 130–150 ms
            ],
            'CPZ': [
                (100, 120),  # early N1: 100–120 ms
                (350, 550)   # P3b component: 350–550 ms
            ],
            'AF3': [
                (100, 130)   # N1 component: 100–130 ms
            ],
            'FP2': [
                (100, 130)   # N1 component: 100–130 ms
            ],
        },
        'r': {  # real‑world driving data
            'O1':  [
                (90, 130),   # N1 component: 90–130 ms
                (400, 550)   # P3b component: 400–550 ms
            ],
            'FC1': [
                (250, 350)   # N2/P3 transition: 250–350 ms
            ],
            'AF3': [
                (130, 150),  # early frontal: 130–150 ms
                (250, 350)   # later frontal: 250–350 ms
            ],
            'F7':  [
                (130, 150),  # early F7 response: 130–150 ms
                (250, 350)   # late F7 response: 250–350 ms
            ],
            'FP1': [
                (130, 160),  # early FP1 frontal: 130–160 ms
                (350, 450)   # late FP1 frontal: 350–450 ms
            ],
        }
    }

    # Channels of interest for each environment (matched case-insensitively)
    channels_map = {
        'v': ['O1', 'FC1', 'FT8', 'CPz', 'Af3', 'Fp2'],
        'r': ['O1', 'FC1', 'AF3', 'F7', 'FP1']
    }

    # Build montage once
    montage, _ = extract_montage(
        args.set_ref_dir, args.set_ref_file, args.drop_channels
    )

    # The plot theme does not change between figures, so set it a single time
    sns.set_theme(style='ticks')

    # Iterate selected environments
    for env in args.envs:
        prefix = env  # 'v' or 'r'
        evokeds = {}

        # Load and compute Evoked per driver & condition
        for dcode, dlabel in drivers.items():
            nsubj = args.n_expert if dcode == 'e' else args.n_novice
            for ccode, clabel in conditions.items():
                cond = f"{prefix}_{ccode}"
                data3d, chs, ts = load_subjects(
                    args.csv_dir, dcode, nsubj, cond, args
                )
                if data3d is None:
                    continue
                label = f"{dlabel} {clabel}"
                ev = make_evoked(data3d, chs, ts, args.sfreq, montage)
                ev.comment = label
                evokeds[label] = ev

        # Plot each channel of interest
        for chan in channels_map[prefix]:
            chan_up = chan.upper()
            windows = intervals_map[prefix].get(chan_up)
            if windows is None:
                continue

            # Gather one (label, time, amplitude) trace per condition that
            # actually contains this channel.
            traces = []
            for label, ev in evokeds.items():
                idx = _channel_index(ev.ch_names, chan_up)
                if idx is None:
                    continue
                # NOTE(review): the 1e6 factor assumes the stored amplitudes
                # are in volts — confirm the unit of the CSV exports.
                traces.append((label, ev.times * 1000, ev.data[idx] * 1e6))
            if not traces:
                continue

            # Create exactly one figure per channel; it is closed after
            # saving so figures do not accumulate across the loop.
            fig, ax = plt.subplots(figsize=(6, 4), dpi=120)

            # Plot each condition
            for cond_label, time_ms, amplitude in traces:
                ax.plot(time_ms, amplitude,
                        label=cond_label,
                        color=colors.get(cond_label, 'gray'),
                        linewidth=2)

            # Baseline at 0 μV
            ax.axhline(0, color='black', linestyle='--', linewidth=1)

            # Highlight temporal windows
            for start, end in windows:
                ax.axvspan(start, end, color='red', alpha=0.15)

            # Labels and legends
            ax.set(
                title=f"{chan} ERP ({'Virtual' if prefix=='v' else 'Real'})",
                xlabel="Time (ms)",
                ylabel="Amplitude (µV)"
            )
            ax.legend(frameon=True, fontsize='small')
            fig.tight_layout()

            # Save outputs
            stem = f"ERP_{prefix.upper()}_{chan_up}"
            fig.savefig(os.path.join(args.output_dir, stem + ".png"),
                        dpi=300, bbox_inches='tight')
            fig.savefig(os.path.join(args.output_dir, stem + ".svg"),
                        dpi=300, bbox_inches='tight')
            plt.close(fig)


if __name__ == "__main__":
    # Script entry point: parse the CLI options and run the full pipeline.
    run_visualization(parse_args())
