In [None]:
#!/usr/bin/env python3
"""
freq_topomap_real.py

Publication‑quality frequency‑domain topographic maps for the **real-world**
driving environment. Computes group‑mean power in dB for expert and novice
subjects within specified frequency bands and time windows, then visualizes:

  1. Expert group topomap (white→blue)
  2. Novice group topomap
  3. Difference topomap (Expert − Novice) with orange→white→blue diverging map

Difference logic:
  ΔdB = mean_expert_dB − mean_novice_dB per electrode;
  orange/white/blue diverging map: blue=expert>novice, orange=novice>expert, white=equal.

Usage:
    python freq_topomap_real.py \
        --set-ref-dir /path/to/eeglab/sets \
        --set-ref-file subject01.set \
        --csv-dir /path/to/freq_csv/output \
        --output-dir /path/to/save/figures \
        --conditions h o c   \
        --bands delta theta beta gamma1 \
        --n-expert 36        \
        --n-novice 64        \
        --contrast 1.5       \
        --percentile 98      \
        --t-range 50-550 150-250
"""
import os
import argparse
import numpy as np
import pandas as pd
import mne
import matplotlib.pyplot as plt
from mne.channels import layout
from matplotlib.colors import LinearSegmentedColormap

# Constants
# Channel labels treated as auxiliary; is_aux() compares against these
# case-insensitively, and matching channels are excluded from both the
# montage and the CSV data.
CHANNELS_TO_DROP = [
    'HEO', 'VEO',
    'CB1', 'CB2',
    'EKG', 'EMG',
    'TRIGGER',
    'M1', 'M2',
]
# Environment code embedded in the per-subject CSV folder names
# ("<driver>_<id>_<ENV>_<cond>_<band>"); 'r' = real environment.
ENV = 'r'

# ----------------------------------------------------------------------------
# Argument Parsing
# ----------------------------------------------------------------------------
def parse_args(argv=None):
    """Parse the command line for the real-world topomap script.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse. ``None`` (the default) parses ``sys.argv[1:]``.

    Returns
    -------
    argparse.Namespace
        ``t_ranges`` holds the raw ``"tmin-tmax"`` strings. Each one is
        already validated to be two integers with ``tmin < tmax``.
        ``percentile`` is an int in [0, 100].
    """
    def time_window(text):
        # Validate up front so a typo fails before any data is loaded.
        # Return the string unchanged, because callers split it themselves.
        try:
            tmin, tmax = (int(part) for part in text.split('-'))
        except ValueError:
            raise argparse.ArgumentTypeError(
                f"invalid time window {text!r}: expected 'tmin-tmax' in ms, e.g. 150-550"
            ) from None
        if tmin >= tmax:
            raise argparse.ArgumentTypeError(
                f"invalid time window {text!r}: tmin must be smaller than tmax"
            )
        return text

    def percentile(text):
        value = int(text)  # ValueError -> argparse reports "invalid percentile value"
        if not 0 <= value <= 100:
            raise argparse.ArgumentTypeError(
                f"percentile must be between 0 and 100, got {value}"
            )
        return value

    p = argparse.ArgumentParser(
        description="Frequency‑domain topomap visualization for real-world driving data"
    )
    p.add_argument('--set-ref-dir', required=True,
                   help='Directory of EEGLAB .set files')
    p.add_argument('--set-ref-file', required=True,
                   help='.set filename')
    p.add_argument('--csv-dir', required=True,
                   help='Root directory of frequency CSV subfolders')
    p.add_argument('--output-dir', required=True,
                   help='Directory to save topomap figures')
    p.add_argument('--conditions', nargs='+', choices=['h', 'o', 'c'], required=True,
                   help='Condition codes: h (hazard), o (occlusion), c (control)')
    p.add_argument('--bands', nargs='+', choices=['delta', 'theta', 'beta', 'gamma1'], required=True,
                   help='Frequency bands to plot')
    p.add_argument('--n-expert', type=int, default=36,
                   help='Number of expert subjects')
    p.add_argument('--n-novice', type=int, default=64,
                   help='Number of novice subjects')
    p.add_argument('--contrast', type=float, default=1.5,
                   help='Contrast scaling factor for dB')
    p.add_argument('--percentile', type=percentile, default=98,
                   help='Percentile for sequential vmin/vmax')
    p.add_argument('--t-range', dest='t_ranges', nargs='+', required=True,
                   type=time_window,
                   help='Time windows (ms), e.g. 150-550 50-350')
    return p.parse_args(argv)

# ----------------------------------------------------------------------------
# Utility Functions
# ----------------------------------------------------------------------------
def rename_channel(raw: str) -> str:
    """Normalise a channel label to leading-capital form, e.g. ``'FP1'`` -> ``'Fp1'``.

    Surrounding whitespace is stripped. An empty or whitespace-only label
    yields ``''`` instead of raising ``IndexError``.
    """
    upper = raw.strip().upper()
    # Slicing ([:1]) instead of indexing ([0]) keeps the empty case safe.
    return upper[:1] + upper[1:].lower()

def is_aux(ch: str) -> bool:
    """Return True if *ch* equals an entry of CHANNELS_TO_DROP, ignoring case."""
    wanted = ch.upper()
    return any(wanted == dropped.upper() for dropped in CHANNELS_TO_DROP)

def extract_ch_name(raw: str) -> str:
    """Normalise the first whitespace-delimited token of *raw* as a channel label.

    Any further tokens are discarded. A blank *raw* raises ``IndexError``.
    """
    tokens = raw.split()
    head = tokens[0]
    return rename_channel(head)

def extract_montage(set_dir: str, set_file: str):
    """Build a head-frame DigMontage from the channel positions of an EEGLAB file.

    Channel names are normalised with rename_channel and auxiliary channels
    (see is_aux) are dropped before positions are read.

    Returns
    -------
    montage : mne.channels.DigMontage
        Each channel's position is the first three entries of its ``loc``.
    ch_names : list of str
        The remaining normalised, non-auxiliary channel names.
    """
    eeglab_path = os.path.join(set_dir, set_file)
    raw = mne.io.read_raw_eeglab(eeglab_path, preload=True, verbose=False)
    raw.rename_channels({name: rename_channel(name) for name in raw.ch_names})
    raw.pick([name for name in raw.ch_names if not is_aux(name)])

    positions = {}
    for ch_info in raw.info['chs']:
        positions[ch_info['ch_name']] = ch_info['loc'][:3]

    montage = mne.channels.make_dig_montage(ch_pos=positions, coord_frame='head')
    return montage, raw.ch_names

def load_single_csv(path: str):
    """Load one frequency CSV into ``(data, channel_names, times)``.

    The indexing assumes this layout: row 0 holds the time axis in columns
    1.., and each following row holds a channel label in column 0 with that
    channel's values in the remaining columns. Times are multiplied by 1000
    (presumably seconds -> ms; confirm against the exporter). Rows with a
    missing or blank label, and auxiliary channels (is_aux), are dropped.

    Returns
    -------
    data : ndarray of float, shape (n_kept_channels, n_times)
    names : list of str
        Normalised channel labels, in row order.
    times : ndarray, shape (n_times,)
    All three are ``None`` when the file has fewer than two non-empty rows
    or columns, or no usable channel rows.
    """
    df = pd.read_csv(path, header=None)
    # pandas >= 1.0 rejects axis=[0, 1]; drop all-empty rows, then
    # all-empty columns (what the multi-axis call used to do).
    df = df.dropna(how='all', axis=0).dropna(how='all', axis=1)
    if df.shape[0] < 2 or df.shape[1] < 2:
        return None, None, None

    times = df.iloc[0, 1:].astype(float).to_numpy() * 1000
    labels = df.iloc[1:, 0]
    values = df.iloc[1:, 1:].to_numpy(dtype=float)

    keep, names = [], []
    for row, label in enumerate(labels):
        # A missing label would become a bogus 'Nan' channel; a blank one
        # would make extract_ch_name raise IndexError.
        if pd.isna(label) or not str(label).strip():
            continue
        ch = extract_ch_name(str(label))
        if is_aux(ch):
            continue
        keep.append(row)
        names.append(ch)

    if not keep:
        return None, None, None
    return values[keep, :], names, times

def subselect_time(mat: np.ndarray, times: np.ndarray, tmin: int, tmax: int):
    """Return the columns of *mat* whose time stamp lies in the half-open window ``[tmin, tmax)``."""
    in_window = np.logical_and(times >= tmin, times < tmax)
    return mat[:, in_window]

def amplitude_to_dB(mat: np.ndarray, contrast: float) -> np.ndarray:
    """Collapse each row of *mat* to one contrast-scaled decibel value.

    The steps are:

    1. Take the row means and scale them by 1e12 (presumably V^2 -> uV^2;
       confirm against the CSV exporter).
    2. Add 1e-30, so an exact zero maps to -300 dB instead of log10(0).
    3. Convert with ``10 * log10`` and multiply by *contrast*.

    An input with no elements yields one NaN per row.
    """
    n_rows = mat.shape[0]
    if not mat.size:
        return np.full(n_rows, np.nan)
    scaled_power = mat.mean(axis=1) * 1e12 + 1e-30
    decibels = 10 * np.log10(scaled_power)
    return decibels * contrast

def load_group_data(csv_dir, driver, n_subj, cond, band, tmin, tmax, contrast):
    """Load one dB vector per subject for a group / condition / band / window.

    For each subject id 1..n_subj, the folder
    ``<csv_dir>/<driver>_<sid:02d>_<ENV>_<cond>_<band>`` is scanned (missing
    folders are skipped). Its ``.csv`` files are processed in sorted order:
    each is loaded, cropped to ``[tmin, tmax)``, and concatenated along the
    time axis. The result is reduced to one value per channel with
    amplitude_to_dB.

    Channel consistency:
      * Within a subject, a file whose channel list differs from the files
        already used is skipped with a warning. Concatenating it would mix
        electrodes row-wise.
      * Across subjects, the first loaded subject fixes the channel order. A
        later subject with the same (duplicate-free) channel set in another
        order is reordered to match. Any other mismatch is skipped with a
        warning.

    Returns
    -------
    (ndarray of shape (n_loaded_subjects, n_channels), list of str)
        The second element is the channel names. Returns ``(None, None)``
        when no subject yielded data.
    """
    import warnings

    group, names = [], None
    for sid in range(1, n_subj + 1):
        folder = os.path.join(csv_dir, f"{driver}_{sid:02d}_{ENV}_{cond}_{band}")
        if not os.path.isdir(folder):
            continue

        windows, subj_chs = [], None
        for fname in sorted(os.listdir(folder)):
            if not fname.lower().endswith('.csv'):
                continue
            csv_path = os.path.join(folder, fname)
            data, chs, times = load_single_csv(csv_path)
            if data is None:
                continue
            if subj_chs is not None and chs != subj_chs:
                warnings.warn(f"{csv_path}: channel list differs from earlier "
                              f"files of this subject; file skipped")
                continue
            wnd = subselect_time(data, times, tmin, tmax)
            if wnd.size == 0:
                continue
            # Label the subject with the channels of files actually used.
            subj_chs = chs
            windows.append(wnd)
        if not windows:
            continue

        db = amplitude_to_dB(np.concatenate(windows, axis=1), contrast)
        if names is None:
            names = subj_chs
        elif subj_chs != names:
            same_set = sorted(subj_chs) == sorted(names)
            if not same_set or len(set(names)) != len(names):
                warnings.warn(f"{folder}: channel set differs from the first "
                              f"subject; subject skipped")
                continue
            db = db[[subj_chs.index(ch) for ch in names]]
        group.append(db)

    if not group:
        return None, None
    return np.vstack(group), names

def plot_topomap(data, ch_names, montage, title, out_path, vmin, vmax, cmap):
    """Draw one scalp topomap and save it as ``out_path + '.png'`` (300 dpi) and ``'.svg'``.

    Parameters
    ----------
    data : array-like
        One value per channel, ordered like *ch_names*.
    ch_names : list of str
        Channel labels whose positions are taken from *montage*.
        ``on_missing='ignore'`` suppresses the error for labels absent from it.
    montage : mne.channels.DigMontage
    title : str
        Used both as the axes title and as the colour-bar label.
    out_path : str
        Output path without extension.
    vmin, vmax : float
        Colour-scale limits.
    cmap : matplotlib colormap
    """
    info = mne.create_info(ch_names, sfreq=250, ch_types='eeg')
    # The dummy RawArray only carries the montage; its data is never plotted.
    raw = mne.io.RawArray(np.zeros((len(ch_names), 10)), info, verbose=False)
    raw.set_montage(montage, on_missing='ignore')
    pos = layout._find_topomap_coords(raw.info, picks=np.arange(len(ch_names)), ignore_overlap=True)

    fig, ax = plt.subplots(figsize=(5, 4))
    try:
        im, _ = mne.viz.plot_topomap(
            data, pos, axes=ax, show=False,
            outlines='head', sensors=True,
            sphere=0.095, contours=8,
            cmap=cmap, vlim=(vmin, vmax)
        )
        cbar = fig.colorbar(im, ax=ax, fraction=0.045, shrink=0.8)
        cbar.set_label(title, rotation=270, labelpad=10, fontsize=8)
        ax.set_title(title, fontsize=10)
        fig.savefig(out_path + ".png", dpi=300, bbox_inches='tight')
        fig.savefig(out_path + ".svg", format='svg', bbox_inches='tight')
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)

def _align_channels(values, names, reference):
    """Reorder per-channel *values* (ordered as *names*) into *reference* order.

    Returns None when the lists are not the same set of unique labels, since
    the values then cannot be matched electrode by electrode.
    """
    if names == reference:
        return values
    if len(set(reference)) != len(reference) or sorted(names) != sorted(reference):
        return None
    return values[[names.index(ch) for ch in reference]]


def _sequential_limits(values, percentile):
    """Colour limits spanning the [100-p, p] percentiles of the finite *values*."""
    finite = values[np.isfinite(values)]
    if finite.size == 0:
        return 0.0, 1.0
    p = min(max(float(percentile), 0.0), 100.0)
    lo, hi = np.percentile(finite, sorted((100.0 - p, p)))
    if hi <= lo:  # flat map: widen so the colour scale is not degenerate
        lo, hi = lo - 0.5, hi + 0.5
    return float(lo), float(hi)


def _symmetric_limit(values):
    """Largest finite |value|; 1.0 when that is zero or no finite value exists."""
    finite = np.abs(values[np.isfinite(values)])
    limit = float(finite.max()) if finite.size else 0.0
    return limit if limit > 0 else 1.0


def main():
    """Render Expert, Novice and Expert−Novice topomaps for every requested
    condition × band × time-window combination.

    The Expert and Novice maps share one sequential colour scale. Its limits
    are the [100-p, p] percentiles (p = ``--percentile``) of both group
    means. The difference map uses a symmetric scale around 0.

    A combination is skipped, with a printed message, when either group has
    no data or when the two groups' channel sets differ. A differing channel
    *order* is realigned to the expert order.
    """
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    montage, _ = extract_montage(args.set_ref_dir, args.set_ref_file)
    cmap_seq = LinearSegmentedColormap.from_list('seq', ['white', '#1F77B4'])
    cmap_div = LinearSegmentedColormap.from_list('div', ['#FF7F0E', 'white', '#1F77B4'])
    for cond in args.conditions:
        for band in args.bands:
            for tr in args.t_ranges:
                tmin, tmax = map(int, tr.split('-'))
                label = f"{band}_{cond}_{tmin}-{tmax}ms"
                e_data, chs = load_group_data(
                    args.csv_dir, 'e', args.n_expert, cond, band, tmin, tmax, args.contrast
                )
                n_data, n_chs = load_group_data(
                    args.csv_dir, 'n', args.n_novice, cond, band, tmin, tmax, args.contrast
                )
                if e_data is None or n_data is None:
                    missing = [g for g, d in (('expert', e_data), ('novice', n_data)) if d is None]
                    print(f"Skipping {label}: no {' / '.join(missing)} data found")
                    continue
                mean_e = np.nanmean(e_data, axis=0)
                # The difference map subtracts electrode-wise, so the novice
                # means must follow the expert channel order.
                mean_n = _align_channels(np.nanmean(n_data, axis=0), n_chs, chs)
                if mean_n is None:
                    print(f"Skipping {label}: expert and novice channel sets differ")
                    continue
                vmin, vmax = _sequential_limits(
                    np.concatenate([mean_e, mean_n]), args.percentile
                )
                # Expert topomap
                title_e = f"{band}_{cond}_{tmin}-{tmax}ms Expert"
                out_e = os.path.join(args.output_dir, f"Expert_{band}_{cond}_{tr}")
                plot_topomap(mean_e, chs, montage, title_e, out_e, vmin, vmax, cmap_seq)
                # Novice topomap
                title_n = f"{band}_{cond}_{tmin}-{tmax}ms Novice"
                out_n = os.path.join(args.output_dir, f"Novice_{band}_{cond}_{tr}")
                plot_topomap(mean_n, chs, montage, title_n, out_n, vmin, vmax, cmap_seq)
                # Difference topomap (blue: expert > novice, orange: novice > expert)
                diff = mean_e - mean_n
                max_abs = _symmetric_limit(diff)
                title_d = f"{band}_{cond}_{tmin}-{tmax}ms Difference"
                out_d = os.path.join(args.output_dir, f"Diff_{band}_{cond}_{tr}")
                plot_topomap(diff, chs, montage, title_d, out_d, -max_abs, max_abs, cmap_div)

if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()
