In [None]:
#!/usr/bin/env python3
"""
freq_topomap_virtual.py

Publication‑quality frequency‑domain topographic maps for the **virtual**
driving environment. Computes group‑mean power in decibels for expert and novice
subjects within user‑specified frequency bands and time windows, then visualizes:

  1. Expert group topomap (sequential cmap: white→blue)
  2. Novice group topomap (sequential cmap)
  3. Difference topomap (Expert − Novice) with diverging cmap (orange→white→blue)

Difference logic:
  ΔdB = mean_expert_dB − mean_novice_dB per electrode;
  orange/blue diverging map: blue=expert>novice, orange=novice>expert, white=equal.

Usage:
    python freq_topomap_virtual.py \
        --set-ref-dir /path/to/eeglab/sets \
        --set-ref-file subject01.set \
        --csv-dir /path/to/freq_csv/output \
        --output-dir /path/to/save/figures \
        --conditions h o c   \
        --bands theta gamma1 \
        --n-expert 36        \
        --n-novice 64        \
        --contrast 1.5       \
        --percentile 98      \
        --t-range 50-350 250-350
"""
import os
import argparse
import numpy as np
import pandas as pd
import mne
import matplotlib.pyplot as plt
from mne.channels import layout
from matplotlib.colors import LinearSegmentedColormap

# Constants
CHANNELS_TO_DROP = ['HEO','VEO','CB1','CB2','EKG','EMG','TRIGGER','M1','M2']
ENV = 'v'  # virtual environment code

# ----------------------------------------------------------------------------
# Parsing arguments
# ----------------------------------------------------------------------------
def parse_args():
    """
    Build the command-line parser and parse ``sys.argv``.

    Returns the argparse namespace with the attributes ``set_ref_dir``,
    ``set_ref_file``, ``csv_dir``, ``output_dir``, ``conditions``, ``bands``,
    ``n_expert``, ``n_novice``, ``contrast``, ``percentile`` and ``t_ranges``.
    """
    parser = argparse.ArgumentParser(
        description="Frequency‑domain topomap visualization for virtual driving data"
    )

    # Mandatory path-like options (all plain strings).
    required_paths = (
        ('--set-ref-dir', 'Directory of EEGLAB .set files for montage'),
        ('--set-ref-file', '.set filename'),
        ('--csv-dir', 'Root directory of frequency CSV subfolders'),
        ('--output-dir', 'Directory to save topomap figures'),
    )
    for flag, help_text in required_paths:
        parser.add_argument(flag, required=True, help=help_text)

    # Mandatory multi-value selectors restricted to known codes.
    parser.add_argument(
        '--conditions', nargs='+', choices=['h', 'o', 'c'], required=True,
        help='Condition codes: h (hazard), o (occlusion), c (control)',
    )
    parser.add_argument(
        '--bands', nargs='+', choices=['theta', 'gamma1'], required=True,
        help='Frequency bands: theta, gamma1',
    )

    # Optional numeric settings with defaults.
    parser.add_argument(
        '--n-expert', type=int, default=36,
        help='Number of expert subjects',
    )
    parser.add_argument(
        '--n-novice', type=int, default=64,
        help='Number of novice subjects',
    )
    parser.add_argument(
        '--contrast', type=float, default=1.5,
        help='Contrast scaling factor for dB conversion',
    )
    parser.add_argument(
        '--percentile', type=int, default=98,
        help='Percentile for vmax adjustment in sequential plots',
    )

    # One or more "tmin-tmax" strings, stored under ``t_ranges``.
    parser.add_argument(
        '--t-range', dest='t_ranges', nargs='+', required=True,
        help='Time windows (ms), e.g. 50-350 250-350',
    )
    return parser.parse_args()

# ----------------------------------------------------------------------------
# Utility Functions
# ----------------------------------------------------------------------------
def rename_channel(raw: str) -> str:
    """Normalize a channel label: trim it, upper-case the first character
    and lower-case the rest (e.g. 'FC1' -> 'Fc1'). Blank input yields ''."""
    label = raw.strip().upper()
    if not label:
        return label
    head, tail = label[0], label[1:]
    return head + tail.lower()

def is_aux(ch: str) -> bool:
    """Return True when *ch* matches, case-insensitively, an entry of
    CHANNELS_TO_DROP."""
    wanted = ch.upper()
    return any(wanted == dropped.upper() for dropped in CHANNELS_TO_DROP)

def extract_ch_name(raw: str) -> str:
    """Return the normalized first whitespace-delimited token of *raw*.

    Raises IndexError when *raw* is empty or whitespace-only.
    """
    first_token = raw.split(None, 1)[0]
    return rename_channel(first_token)

def extract_montage(set_dir: str, set_file: str):
    """
    Read an EEGLAB .set file and derive a digitized montage from it.

    Channel labels are normalized with ``rename_channel`` and every channel
    flagged by ``is_aux`` is removed before the positions are collected.

    Returns (montage, channel_names), where channel_names are the labels
    remaining after the auxiliary channels were dropped.
    """
    set_path = os.path.join(set_dir, set_file)
    recording = mne.io.read_raw_eeglab(set_path, preload=True, verbose=False)

    # Normalize labels first so the auxiliary check sees the final names.
    renames = {}
    for label in recording.ch_names:
        renames[label] = rename_channel(label)
    recording.rename_channels(renames)

    recording.pick([label for label in recording.ch_names if not is_aux(label)])

    # First three entries of each channel's 'loc' vector are used as position.
    ch_pos = {}
    for i, label in enumerate(recording.ch_names):
        ch_pos[label] = recording.info['chs'][i]['loc'][:3]

    montage = mne.channels.make_dig_montage(ch_pos=ch_pos, coord_frame='head')
    return montage, recording.ch_names

def load_single_csv(path: str):
    """
    Load one header-less frequency CSV.

    Layout expected by this reader: the first row (from the second column on)
    holds time stamps that are multiplied by 1000 here (seconds -> ms per the
    file's convention); the first column (from the second row on) holds
    channel labels; the remaining cells hold the values.

    Rows and columns that are entirely empty are removed first. Auxiliary
    channels (see ``is_aux``) are filtered out.

    Returns:
        (data, names, times_ms) where ``data`` has one row per kept channel,
        or (None, None, None) when the file is empty, too small, or contains
        only auxiliary channels.
    """
    try:
        df = pd.read_csv(path, header=None)
    except pd.errors.EmptyDataError:
        # A zero-byte file is just another unusable file, not a fatal error.
        return None, None, None

    # pandas >= 1.0 raises TypeError for ``axis=[0, 1]``; dropping empty rows
    # and then empty columns reproduces what the list form used to do.
    df = df.dropna(how='all', axis=0).dropna(how='all', axis=1)
    if df.shape[0] < 2 or df.shape[1] < 2:
        return None, None, None

    times = df.iloc[0, 1:].astype(float).values * 1000
    raw_names = df.iloc[1:, 0].astype(str).values
    data = df.iloc[1:, 1:].values

    names = [extract_ch_name(nm) for nm in raw_names]
    keep = [i for i, ch in enumerate(names) if not is_aux(ch)]
    if not keep:
        return None, None, None
    return data[keep, :], [names[i] for i in keep], times

def subselect_time(mat: np.ndarray, times: np.ndarray, tmin: int, tmax: int):
    """Return the columns of *mat* whose time stamp t satisfies
    tmin <= t < tmax (same unit as *times*)."""
    not_too_early = times >= tmin
    not_too_late = times < tmax
    in_window = np.logical_and(not_too_early, not_too_late)
    return mat[:, in_window]

def amplitude_to_dB(mat: np.ndarray, contrast: float) -> np.ndarray:
    """
    Average each row of *mat* over its columns and express it in scaled dB:
      contrast * 10*log10(row_mean*1e12 + 1e-30)

    An empty *mat* yields a vector of NaN with one entry per row.
    """
    n_rows = mat.shape[0]
    if mat.size == 0:
        return np.full(n_rows, np.nan)
    row_mean = np.mean(mat, axis=1)
    # The tiny offset keeps log10 finite when a row averages to exactly zero.
    scaled = row_mean * 1e12 + 1e-30
    decibel = 10 * np.log10(scaled)
    return decibel * contrast

def load_group_data(csv_dir: str, driver: str, n_subj: int,
                    cond: str, band: str, tmin: int, tmax: int,
                    contrast: float) -> (np.ndarray, list):
    """
    Aggregate per-subject dB vectors for one group.

    For every subject id 1..n_subj the folder
    ``{driver}_{sid:02d}_{ENV}_{cond}_{band}`` under *csv_dir* is scanned for
    CSV files. Each usable file is cut to the [tmin, tmax) window, the windows
    are concatenated along time, and the result is converted with
    ``amplitude_to_dB``.

    Args:
        csv_dir: Root directory containing the per-subject folders.
        driver: Group prefix used in the folder name (e.g. 'e' or 'n').
        n_subj: Highest subject id to look for.
        cond, band: Condition and band codes used in the folder name.
        tmin, tmax: Time window in ms, half-open.
        contrast: Scaling factor forwarded to ``amplitude_to_dB``.

    Returns:
        (array of shape n_subjects x n_channels, channel list), or
        (None, None) when no subject contributed data.

    Files whose channel list differs from the subject's first usable file,
    and subjects whose channel list differs from the first accepted subject,
    are skipped with a warning.
    """
    import warnings

    group, names = [], None
    for sid in range(1, n_subj + 1):
        folder = os.path.join(csv_dir, f"{driver}_{sid:02d}_{ENV}_{cond}_{band}")
        if not os.path.isdir(folder):
            continue

        mats, subj_chs = [], None
        for fname in sorted(os.listdir(folder)):
            if not fname.lower().endswith('.csv'):
                continue
            data, chs, times = load_single_csv(os.path.join(folder, fname))
            if data is None:
                continue
            wnd = subselect_time(data, times, tmin, tmax)
            if wnd.size == 0:
                continue
            # Remember the channels of the files that are actually used; the
            # loop variable from the *last* file may be None or unrelated.
            if subj_chs is None:
                subj_chs = list(chs)
            elif list(chs) != subj_chs:
                warnings.warn(
                    f"Skipping {os.path.join(folder, fname)}: channel list "
                    "differs from the subject's other files"
                )
                continue
            mats.append(wnd)
        if not mats:
            continue

        if names is None:
            names = subj_chs
        elif subj_chs != names:
            warnings.warn(
                f"Skipping subject folder {folder}: channel list differs "
                "from the rest of the group"
            )
            continue

        group.append(amplitude_to_dB(np.concatenate(mats, axis=1), contrast))

    if not group:
        return None, None
    return np.vstack(group), names

def plot_topomap(data: np.ndarray, ch_names: list, montage,
                 title: str, out_path: str,
                 vmin: float, vmax: float,
                 cmap) -> None:
    """
    Draw one topographic map and write it to ``out_path + ".png"`` (300 dpi)
    and ``out_path + ".svg"``.

    *title* is used both as the axes title and as the colorbar label;
    (vmin, vmax) are the colour limits and *cmap* the colormap.
    """
    n_channels = len(ch_names)

    # A zero-filled placeholder recording is only needed to carry the
    # montage so that 2-D sensor coordinates can be derived from it.
    placeholder_info = mne.create_info(ch_names, sfreq=250, ch_types='eeg')
    placeholder = mne.io.RawArray(np.zeros((n_channels, 10)), placeholder_info)
    placeholder.set_montage(montage, on_missing='ignore')
    sensor_xy = layout._find_topomap_coords(
        placeholder.info, picks=np.arange(n_channels), ignore_overlap=True
    )

    fig, ax = plt.subplots(figsize=(5, 4))
    topo_options = dict(
        axes=ax, show=False,
        outlines='head', sensors=True,
        sphere=0.095, contours=8,
        cmap=cmap, vlim=(vmin, vmax),
    )
    image, _ = mne.viz.plot_topomap(data, sensor_xy, **topo_options)

    colorbar = plt.colorbar(image, ax=ax, fraction=0.045, shrink=0.8)
    colorbar.set_label(title, rotation=270, labelpad=10, fontsize=8)
    plt.title(title, fontsize=10)

    fig.savefig(f"{out_path}.png", dpi=300, bbox_inches='tight')
    fig.savefig(f"{out_path}.svg", format='svg', bbox_inches='tight')
    plt.close(fig)

# ----------------------------------------------------------------------------
# Main Workflow
# ----------------------------------------------------------------------------
# Human-readable labels used in figure titles, keyed by the CLI codes.
CONDITION_LABELS = {'h': 'Overt Hazard', 'o': 'Covert Hazard', 'c': 'Control'}
BAND_LABELS = {'theta': 'Theta (4-8 Hz)', 'gamma1': 'Gamma (30-40 Hz)'}


def main():
    """
    Run the full pipeline: parse the command line, build the montage, and for
    every (condition, band, time window) combination save an Expert, a Novice
    and a Difference (Expert - Novice) topomap into the output directory.

    Combinations for which either group has no data, or for which the two
    groups report different channel lists, are skipped with a message.
    """
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    montage, _ = extract_montage(args.set_ref_dir, args.set_ref_file)

    # Colormaps: sequential for group maps, diverging for the difference.
    cmap_seq = LinearSegmentedColormap.from_list('seq', ['white', '#1F77B4'])
    cmap_div = LinearSegmentedColormap.from_list(
        'div', ['#FF7F0E', 'white', '#1F77B4']
    )

    for cond in args.conditions:
        cond_label = CONDITION_LABELS[cond]
        for band in args.bands:
            band_label = BAND_LABELS[band]
            for tr in args.t_ranges:
                tmin, tmax = map(int, tr.split('-'))

                # Load expert and novice group data
                e_data, chs = load_group_data(
                    args.csv_dir, 'e', args.n_expert, cond, band,
                    tmin, tmax, args.contrast
                )
                n_data, n_chs = load_group_data(
                    args.csv_dir, 'n', args.n_novice, cond, band,
                    tmin, tmax, args.contrast
                )
                if e_data is None or n_data is None:
                    print(f"No data for {band}/{cond}/{tr}; skipping")
                    continue
                # The difference is taken element-wise, so both groups must
                # report the same channels in the same order.
                if list(chs) != list(n_chs):
                    print(f"Channel mismatch between groups for "
                          f"{band}/{cond}/{tr}; skipping")
                    continue

                mean_e = np.nanmean(e_data, axis=0)
                mean_n = np.nanmean(n_data, axis=0)
                pooled = np.concatenate([mean_e, mean_n])
                if not np.isfinite(pooled).any():
                    print(f"No finite values for {band}/{cond}/{tr}; skipping")
                    continue

                # Shared colour scale for both group maps; the upper limit is
                # the requested percentile so single outliers do not wash out
                # the map (this is what --percentile is documented to do).
                vmin = 0.0
                vmax = float(np.nanpercentile(pooled, args.percentile))
                if not vmax > vmin:
                    # Non-positive dB values: fall back to the data range.
                    vmin = float(np.nanmin(pooled))
                    vmax = float(np.nanmax(pooled))
                    if not vmax > vmin:
                        vmax = vmin + 1.0

                # Expert
                title_e = f"{band_label}_{cond_label}_{tmin}-{tmax}ms Expert"
                out_e = os.path.join(args.output_dir, f"Expert_{band}_{cond}_{tr}")
                plot_topomap(mean_e, chs, montage, title_e, out_e,
                             vmin, vmax, cmap_seq)

                # Novice
                title_n = f"{band_label}_{cond_label}_{tmin}-{tmax}ms Novice"
                out_n = os.path.join(args.output_dir, f"Novice_{band}_{cond}_{tr}")
                plot_topomap(mean_n, chs, montage, title_n, out_n,
                             vmin, vmax, cmap_seq)

                # Difference: symmetric limits keep white at zero.
                diff = mean_e - mean_n
                finite_abs = np.abs(diff[np.isfinite(diff)])
                max_abs = float(finite_abs.max()) if finite_abs.size else 0.0
                if not max_abs > 0:
                    max_abs = 1.0
                title_d = f"{band_label}_{cond_label}_{tmin}-{tmax}ms Difference"
                out_d = os.path.join(args.output_dir, f"Diff_{band}_{cond}_{tr}")
                plot_topomap(diff, chs, montage, title_d, out_d,
                             -max_abs, max_abs, cmap_div)


if __name__ == '__main__':
    main()
