In [None]:
# eeg_analysis_enhanced.py

# === 1. Dependencies ===
!pip install --quiet mne pyxdf autoreject

import os
import warnings

import numpy as np
import pandas as pd
import mne
import pyxdf
from mne.preprocessing import ICA, create_eog_epochs
from mne.time_frequency import psd_array_welch, tfr_morlet
from autoreject import AutoReject

# === 2. Helpers ===

def _xdf_channel_labels(stream, n_channels):
    """Return one label per data channel of an XDF stream, or generic names.

    Falls back to ``EEG001``, ``EEG002``, ... when the stream has no channel
    description (pyxdf stores an empty ``desc`` as ``[None]``) or when the
    number of described channels does not match the data.
    """
    try:
        channels = stream['info']['desc'][0]['channels'][0]['channel']
        labels = [ch['label'][0] for ch in channels]
    except (KeyError, IndexError, TypeError):
        labels = []
    if len(labels) != n_channels:
        labels = [f"EEG{i + 1:03d}" for i in range(n_channels)]
    return labels


def load_eeg_from_xdf(path):
    """Load the first EEG stream of an XDF file as an MNE ``RawArray``.

    Parameters
    ----------
    path : str
        Path to the ``.xdf`` recording.

    Returns
    -------
    mne.io.RawArray
        Every channel of the first stream whose ``type`` is ``'EEG'``, all
        typed as ``'eeg'``.

    Raises
    ------
    RuntimeError
        If the file has no EEG stream, or the stream's sampling rate can be
        neither read nor estimated from its timestamps.
    """
    streams, _ = pyxdf.load_xdf(path)
    eegs = [s for s in streams if s['info']['type'][0] == 'EEG']
    if not eegs:
        raise RuntimeError(f"No EEG in {path}")
    e = eegs[0]

    # pyxdf gives (n_samples, n_channels); MNE expects (n_channels, n_samples).
    # NOTE(review): samples are passed through unscaled while MNE assumes
    # volts -- confirm the recording's unit (XDF EEG is often microvolts).
    data = np.asarray(e['time_series'], dtype=float).T

    sfreq = float(e['info']['nominal_srate'][0])
    if sfreq <= 0:
        # Irregular-rate stream: estimate the rate from the sample timestamps.
        ts = np.asarray(e['time_stamps'], dtype=float)
        if ts.size < 2 or ts[-1] <= ts[0]:
            raise RuntimeError(
                f"Cannot determine sampling rate of EEG stream in {path}")
        sfreq = (ts.size - 1) / (ts[-1] - ts[0])

    chs = _xdf_channel_labels(e, data.shape[0])
    info = mne.create_info(ch_names=chs, sfreq=sfreq,
                           ch_types=['eeg'] * len(chs))
    return mne.io.RawArray(data, info)


def _drop_nonfinite_channels(raw):
    """Drop, in place, every channel of *raw* that contains NaN or Inf."""
    finite = np.isfinite(raw.get_data()).all(axis=1)
    bad = [ch for ch, ok in zip(raw.ch_names, finite) if not ok]
    if bad:
        raw.drop_channels(bad)


def _high_variance_channels(raw, n_sd=3.0):
    """Return channels whose variance is more than *n_sd* SDs from the mean variance."""
    var = np.var(raw.get_data(), axis=1)
    limit = n_sd * np.std(var)
    deviation = np.abs(var - np.mean(var))
    return [ch for ch, d in zip(raw.ch_names, deviation) if d > limit]


def _find_eog_components(ica, raw, eog_ch):
    """Return indices of ICA components matching *eog_ch*, or [] if unavailable."""
    if eog_ch not in raw.ch_names:
        warnings.warn(f"EOG channel {eog_ch!r} not found; "
                      "ICA applied without excluding components.")
        return []
    try:
        eog_epochs = create_eog_epochs(raw, ch_name=eog_ch)
        # ch_name must be given explicitly: when every channel is typed 'eeg'
        # there is no EOG-typed channel for find_bads_eog to fall back on.
        eog_inds, _ = ica.find_bads_eog(eog_epochs, ch_name=eog_ch)
    except Exception as err:  # best-effort step: report and keep going
        warnings.warn(f"EOG component detection failed ({err!r}); "
                      "ICA applied without excluding components.")
        return []
    return list(eog_inds)


def preprocess_raw(raw, l_freq=1., h_freq=40., notch=(50, 100), ica_comp=0.95,
                   eog_ch='EOG'):
    """Clean an EEG recording and return an ICA-corrected copy.

    Pipeline:
    1. Keep only EEG channels.
    2. Drop non-finite channels.
    3. Set the standard 10-20 montage.
    4. Notch- and band-pass-filter.
    5. Drop any channels that became non-finite.
    6. Mark high-variance channels as bad.
    7. Interpolate (or drop) the bad channels.
    8. Re-reference to the average.
    9. Remove EOG-related ICA components.

    Parameters
    ----------
    raw : mne.io.Raw
        Preloaded recording. NOTE: steps 1-8 modify it in place (channel
        picking/dropping, filtering, re-referencing); only ICA works on a copy.
    l_freq, h_freq : float | None
        Band-pass edges in Hz. ``h_freq`` is skipped (with a warning) when it
        is not below the Nyquist frequency.
    notch : iterable of float
        Line-noise frequencies in Hz; those at or above Nyquist are skipped.
    ica_comp : int | float
        ICA ``n_components`` (a float selects by explained variance).
    eog_ch : str
        Channel used as the EOG reference for ICA artifact detection. If it
        is missing, or detection fails, no component is excluded.

    Returns
    -------
    mne.io.Raw
        ICA-cleaned copy of the preprocessed recording.
    """
    # 1) Select EEG-only channels. With load_eeg_from_xdf every channel is
    #    typed 'eeg', so an EOG lead stored in the EEG stream survives this.
    raw.pick_types(eeg=True)

    # 2) Drop any channels containing NaN/Inf
    _drop_nonfinite_channels(raw)

    # 3) Set 10-20 montage; channels not in the template get no position
    mont = mne.channels.make_standard_montage('standard_1020')
    raw.set_montage(mont, on_missing='ignore')

    # 4) Notch and band-pass filtering, skipping edges at/above Nyquist
    nyq = raw.info['sfreq'] / 2
    freqs = [f for f in notch if f < nyq]
    if freqs:
        raw.notch_filter(freqs, picks='eeg')
    if h_freq is not None and h_freq >= nyq:
        warnings.warn(f"h_freq={h_freq} Hz is not below Nyquist ({nyq} Hz); "
                      "skipping the low-pass edge.")
        h_freq = None
    raw.filter(l_freq, h_freq, picks='eeg', fir_design='firwin')

    # 5) Drop any newly non-finite channels after filtering
    _drop_nonfinite_channels(raw)

    # 6) Mark channels with extreme variance (>3 SD) as bad
    raw.info['bads'] = _high_variance_channels(raw)

    # 7) Interpolate bad channels; fall back to dropping them when MNE cannot
    #    interpolate (e.g. channels lacking montage positions).
    #    NOTE(review): exact exception type raised there varies; confirm.
    try:
        raw.interpolate_bads(reset_bads=True)
    except (ValueError, RuntimeError):
        raw.drop_channels(raw.info['bads'])
        raw.info['bads'] = []

    # 8) Re-reference to average
    raw.set_eeg_reference('average', projection=False)

    # 9) Run ICA and exclude components correlated with the EOG channel
    ica = ICA(n_components=ica_comp, method='fastica', random_state=42)
    ica.fit(raw)
    ica.exclude = _find_eog_components(ica, raw, eog_ch)
    return ica.apply(raw.copy())


def compute_band_power(data, sfreq, bands=None):
    """Return the mean Welch PSD in each frequency band, pooled over signals.

    Parameters
    ----------
    data : array-like, shape (n_signals, n_times) or (n_times,)
        One signal per row. Rows containing NaN, or that are (near-)constant
        (std <= 1e-12, which also excludes rows with Inf), are ignored.
    sfreq : float
        Sampling frequency in Hz.
    bands : dict[str, tuple[float, float]] | None
        Band name -> (low, high) in Hz. A band covers ``low <= f < high``,
        except band(s) whose ``high`` is the overall maximum, which include
        that edge. Defaults to delta 1-4, theta 4-8, alpha 8-13, beta 13-30,
        gamma 30-40 Hz.

    Returns
    -------
    dict[str, float]
        Mean PSD per band, averaged over usable rows and in-band frequency
        bins. NaN when no row is usable or a band contains no bin.
    """
    if bands is None:
        bands = {'delta': (1, 4), 'theta': (4, 8), 'alpha': (8, 13),
                 'beta': (13, 30), 'gamma': (30, 40)}
    if not bands:
        return {}

    x = np.atleast_2d(np.asarray(data, dtype=float))
    usable = (~np.isnan(x).any(axis=1)) & (x.std(axis=1) > 1e-12)
    x = x[usable]
    if x.size == 0:
        return dict.fromkeys(bands, np.nan)
    x = x - x.mean(axis=1, keepdims=True)

    fmin = min(lo for lo, _ in bands.values())
    fmax = max(hi for _, hi in bands.values())
    n_fft = min(2048, x.shape[1])
    psd, freqs = psd_array_welch(x, sfreq=sfreq, fmin=fmin, fmax=fmax,
                                 n_fft=n_fft)

    powers = {}
    for name, (lo, hi) in bands.items():
        upper = (freqs <= hi) if hi == fmax else (freqs < hi)
        in_band = (freqs >= lo) & upper
        powers[name] = psd[:, in_band].mean() if in_band.any() else np.nan
    return powers

# === 3. Main Processing ===

data_dir = "/content/nanoe_eeg_data"
participants = [f"P{ii:02d}" for ii in range(12, 21)]
results = []

# Morlet frequencies (4-40 Hz, log-spaced) shared by every TFR. Defined once
# up front so the driving analysis never depends on an eye-task file having
# been processed first (previously a NameError when those files were absent).
tfr_freqs = np.logspace(np.log10(4), np.log10(40), num=20)
EPOCH_LEN = 2.0                          # s, fixed-length eye-task epochs
DRIVING_TMIN, DRIVING_TMAX = -0.2, 1.0   # s, window around each marker


def _save_tfr(epochs, fname):
    """Compute average Morlet power (no ITC) of *epochs* and save it to *fname*."""
    tfr = tfr_morlet(epochs, freqs=tfr_freqs, n_cycles=tfr_freqs / 2,
                     return_itc=False)
    tfr.save(fname, overwrite=True)


def _process_eye_task(pid, fp, task, cond):
    """Preprocess one eye-task recording, save its ERP/TFR, return band powers."""
    raw = preprocess_raw(load_eeg_from_xdf(fp))
    sfreq = raw.info['sfreq']

    # Non-overlapping fixed-length epochs. tmax stops one sample short of
    # EPOCH_LEN so each epoch has exactly EPOCH_LEN * sfreq samples and the
    # last one does not run past the end of the recording.
    events = mne.make_fixed_length_events(raw, id=1, duration=EPOCH_LEN)
    epochs = mne.Epochs(raw, events, event_id=1, tmin=0,
                        tmax=EPOCH_LEN - 1.0 / sfreq, baseline=None,
                        preload=True)
    epochs_clean = AutoReject().fit_transform(epochs)

    epochs_clean.average().save(f"{pid}_{task}_{cond}-ave.fif", overwrite=True)
    _save_tfr(epochs_clean, f"{pid}_{task}_{cond}-tfr.h5")

    # Band power is computed on the whole continuous preprocessed recording.
    bp = compute_band_power(raw.get_data(), sfreq)
    bp.update(participant=pid, task=task, condition=cond, event=None)
    return bp


def _marker_events(fp, sfreq, n_times):
    """Build an MNE events array from every 'Markers' stream in *fp*.

    Marker timestamps are converted to sample indices relative to the first
    sample of the (first) EEG stream -- the stream load_eeg_from_xdf returns --
    not relative to the first marker. Markers are assumed integer-coded (their
    first value must convert to a number). Events outside ``[0, n_times)``
    are discarded; the rest are sorted by sample (stable, so simultaneous
    markers keep their recorded order).
    """
    streams, _ = pyxdf.load_xdf(fp)
    eeg_t0 = next(float(s['time_stamps'][0]) for s in streams
                  if s['info']['type'][0] == 'EEG')
    markers = [s for s in streams
               if s['info']['type'][0] == 'Markers' and len(s['time_stamps'])]
    if not markers:
        return np.empty((0, 3), dtype=int)

    ts = np.concatenate([np.asarray(s['time_stamps'], dtype=float)
                         for s in markers])
    codes = np.array([v[0] for s in markers for v in s['time_series']])
    codes = codes.astype(float).astype(int)

    samples = np.round((ts - eeg_t0) * sfreq).astype(int)
    keep = (samples >= 0) & (samples < n_times)
    samples, codes = samples[keep], codes[keep]
    order = np.argsort(samples, kind='stable')
    return np.column_stack([samples[order],
                            np.zeros(order.size, dtype=int),
                            codes[order]])


def _process_driving(pid, fp, cond):
    """Preprocess one driving recording; save per-event ERP/TFR; return band powers."""
    raw = preprocess_raw(load_eeg_from_xdf(fp))
    sfreq = raw.info['sfreq']
    events = _marker_events(fp, sfreq, raw.n_times)
    if len(events) == 0:
        warnings.warn(f"{fp}: no markers inside the EEG recording; skipped.")
        return []

    # Epochs start before each marker so baseline (None, 0) is a real
    # pre-stimulus interval rather than the single sample at t=0. Markers
    # sharing a sample keep only the first (MNE raises on repeats otherwise).
    epochs = mne.Epochs(raw, events, event_id=None, tmin=DRIVING_TMIN,
                        tmax=DRIVING_TMAX, baseline=(None, 0), preload=True,
                        event_repeated='drop')
    epochs_clean = AutoReject().fit_transform(epochs)
    remaining = set(epochs_clean.events[:, 2].tolist())

    rows = []
    for key, code in epochs_clean.event_id.items():
        if code not in remaining:  # every epoch of this code was rejected
            continue
        ep = epochs_clean[key]
        ep.average().save(f"{pid}_driving_{cond}_{code}-ave.fif",
                          overwrite=True)
        _save_tfr(ep, f"{pid}_driving_{cond}_{code}-tfr.h5")

        # Post-marker window only, one row per (epoch, channel): each row is
        # a contiguous segment, so Welch never spans an epoch boundary.
        x = ep.copy().crop(tmin=0.0).get_data()
        bp = compute_band_power(x.reshape(-1, x.shape[-1]), sfreq)
        bp.update(participant=pid, task='driving', condition=cond, event=code)
        rows.append(bp)
    return rows


for pid in participants:
    base = os.path.join(data_dir, pid)

    # 3a) Eye tasks (fixed windows, ERP & TFR)
    for task in ['eye_closure', 'eye_opening']:
        for cond in ['with', 'without']:
            fp = os.path.join(base, f"{task}_{cond}.xdf")
            if os.path.isfile(fp):
                results.append(_process_eye_task(pid, fp, task, cond))

    # 3b) Driving tasks (event-based)
    for cond in ['with', 'without']:
        fp = os.path.join(base, f"driving_{cond}.xdf")
        if os.path.isfile(fp):
            results.extend(_process_driving(pid, fp, cond))

# === 4. Aggregate & Save ===

# Passing columns= keeps the column order and still works when no file
# was processed (indexing an empty DataFrame by column raised KeyError).
columns = ['delta', 'theta', 'alpha', 'beta', 'gamma',
           'participant', 'task', 'condition', 'event']
df = pd.DataFrame(results, columns=columns)
df.to_csv("all_participants_eventwise_enhanced.csv", index=False)
print("✅ Enhanced analysis complete and saved.")
