In [88]:
import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning)
import cmlreaders as cml
from cmldask import CMLDask as da
from dask.distributed import wait, as_completed
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.colors as mcolors
import seaborn as sns
import pandas as pd
import xarray as xr
import scipy as scp
import re
from scipy import stats
from ptsa.data.timeseries import *
from statsmodels.stats.multitest import multipletests
import pyedflib
from mne_bids import get_entity_vals
from ReportRawEEG import *

pd.options.display.max_rows = 100
pd.options.display.max_columns = 50
%matplotlib inline
import mne
from mne_bids import BIDSPath, read_raw_bids

In [89]:
subjects = ['LTP606', 'LTP607', 'LTP609', 'LTP610', 'LTP612']
experiment = 'valuecourier'
bids_root = "/home1/maint/LTP_BIDS/"

In [90]:
# Convert to DataFrame
def fix_evs_bids(full_evs):
    value_recalls = full_evs[full_evs.trial_type == "VALUE_RECALL"] 
    words = full_evs[full_evs.trial_type == "WORD"]
    rec_words = full_evs[full_evs.trial_type == "REC_WORD"]
    rec_vv_words = full_evs[full_evs.trial_type == "REC_WORD_VV"]

    # WORD --> storepointtype, recalled--> VALUE_RECALL, REC_WORD, REC_WORD_VV
    word_trial_to_storepointtype = words.set_index("trial")["storepointtype"].to_dict()
    word_trial_to_recalled = words.set_index("trial")["recalled"].to_dict()
    for event_type in ["VALUE_RECALL", "REC_WORD", "REC_WORD_VV"]:
        subset = full_evs[full_evs.trial_type == event_type]
        for idx, row in subset.iterrows():
            trial = row["trial"]
            if trial in word_trial_to_storepointtype:
                full_evs.at[idx, "storepointtype"] = word_trial_to_storepointtype[trial]
            if trial in word_trial_to_recalled:
                full_evs.at[idx, "recalled"] = word_trial_to_recalled[trial]

    # VALUE_RECALL --> actualvalue, valuerecall --> WORD, `REC_WORD`, REC_WORD_VV
    valuerecall_trial_to_actualvalue = value_recalls.set_index("trial")["actualvalue"].to_dict()
    valuerecall_trial_to_valuerecall = value_recalls.set_index("trial")["valuerecall"].to_dict()

    # --- Apply to multi-row event types ---
    for event_type in ["WORD", "REC_WORD", "REC_WORD_VV"]:
        subset = full_evs[full_evs.trial_type == event_type]
        for idx, row in subset.iterrows():
            trial = row["trial"]

            # actualvalue
            if trial in valuerecall_trial_to_actualvalue:
                full_evs.at[idx, "actualvalue"] = valuerecall_trial_to_actualvalue[trial]

            # valuerecall
            if trial in valuerecall_trial_to_valuerecall:
                full_evs.at[idx, "valuerecall"] = valuerecall_trial_to_valuerecall[trial]
                
    return full_evs

def fix_evs_cml(full_evs):
    value_recalls = full_evs[full_evs.type == "VALUE_RECALL"] 
    words = full_evs[full_evs.type == "WORD"]
    rec_words = full_evs[full_evs.type == "REC_WORD"]
    rec_vv_words = full_evs[full_evs.type == "REC_WORD_VV"]

    # WORD --> storepointtype, recalled--> VALUE_RECALL, REC_WORD, REC_WORD_VV
    word_trial_to_storepointtype = words.set_index("trial")["storepointtype"].to_dict()
    word_trial_to_recalled = words.set_index("trial")["recalled"].to_dict()
    for event_type in ["VALUE_RECALL", "REC_WORD", "REC_WORD_VV"]:
        subset = full_evs[full_evs.type == event_type]
        for idx, row in subset.iterrows():
            trial = row["trial"]
            if trial in word_trial_to_storepointtype:
                full_evs.at[idx, "storepointtype"] = word_trial_to_storepointtype[trial]
            if trial in word_trial_to_recalled:
                full_evs.at[idx, "recalled"] = word_trial_to_recalled[trial]

    # VALUE_RECALL --> actualvalue, valuerecall --> WORD, `REC_WORD`, REC_WORD_VV
    valuerecall_trial_to_actualvalue = value_recalls.set_index("trial")["actualvalue"].to_dict()
    valuerecall_trial_to_valuerecall = value_recalls.set_index("trial")["valuerecall"].to_dict()

    # --- Apply to multi-row event types ---
    for event_type in ["WORD", "REC_WORD", "REC_WORD_VV"]:
        subset = full_evs[full_evs.type == event_type]
        for idx, row in subset.iterrows():
            trial = row["trial"]

            # actualvalue
            if trial in valuerecall_trial_to_actualvalue:
                full_evs.at[idx, "actualvalue"] = valuerecall_trial_to_actualvalue[trial]

            # valuerecall
            if trial in valuerecall_trial_to_valuerecall:
                full_evs.at[idx, "valuerecall"] = valuerecall_trial_to_valuerecall[trial]
                
    return full_evs

In [91]:
# sess = 0
# subject = "LTP606"
# ### Compare BIDS vs CML Reader
# bids_root = "/home1/maint/LTP_BIDS/"
# # 1 subject
# # events
# # CML
# reader = cml.CMLReader(subject=subject, experiment="ValueCourier", session=sess)
# evs_cml = reader.load('events')
# evs_cml = fix_evs_cml(evs_cml)
# # evs_cml = evs_cml[(evs_cml['type']=='WORD') & (evs_cml['eegoffset']>=0)]

# # BIDS
# bids_path = BIDSPath(
#     subject=subject,
#     session=str(sess),
#     task="valuecourier",
#     datatype="eeg",
#     root=bids_root,
# )

# # --------------------------
# # Load BIDS events.tsv
# # --------------------------
# events_path = os.path.join(bids_path.directory, bids_path.basename + "_events.tsv")
# evs_bids = pd.read_csv(events_path, sep="\t")
# evs_bids = fix_evs_bids(evs_bids)

# # evs_bids = evs_bids.query("trial_type == 'WORD' and sample >= 0").copy()

In [92]:
# res = compare_behavioral(
#     evs_cml, "CMLReader",
#     evs_bids, "OpenBIDS",
#     options=[
#         "compare_onset_as_diff",
#         "tolerant_numeric",
#         "print_behavior_summary",
#         "print_behavior_col_summary",
#         "print_behavior_mismatches",
#         "return_col_summary",
#         "return_mismatches",
#     ],
#     drop_cols=[],  # add stuff you explicitly don't want compared
# )

In [93]:
# # check eeg
# # CML
# eeg_cml = reader.load_eeg().to_ptsa()

In [94]:
# # bids
# REL_START, REL_STOP = 200 /1000, 3000 /1000
# BUFFER_MS = 1000 /1000
# WIDTH = 6

# FREQS = np.logspace(np.log10(2), np.log10(100), 46)
# NOTCH_BAND = (58., 62.)
# BATCH_EVENTS = 64

# bids_path = BIDSPath(
#     subject="LTP606",
#     session=str(0),
#     task="valuecourier",
#     datatype="eeg",
#     root=bids_root,
# )
 
# raw = read_raw_bids(bids_path)

# raw.set_channel_types({
#     "EXG1": "eog", "EXG2": "eog", "EXG3": "eog", "EXG4": "eog",
#     "EXG5": "misc", "EXG6": "misc", "EXG7": "misc", "EXG8": "misc",
# })

# # raw.drop_channels(
# #     [ch for ch in raw.ch_names if raw.get_channel_types(picks=ch)[0] != "eeg"]
# # )

# # --------------------------
# # Epoch WORD events from annotations
# # --------------------------
# events, event_id = mne.events_from_annotations(raw)
# word_event_id = {k: v for k, v in event_id.items() if (k == "WORD")}

# # ---- CRITICAL: MNE expects SECONDS ----
# # If your constants are in ms, convert here:
# tmin = (-BUFFER_MS)
# tmax = ((REL_STOP + BUFFER_MS))

# epochs_bids = mne.Epochs(
#     raw,
#     events=events,
#     event_id=word_event_id,
#     tmin=tmin,
#     tmax=tmax,
#     baseline=(None, 0),
#     preload=True,
#     event_repeated="merge",
# )

# print("sfreq:", epochs_bids.info["sfreq"])
# print("tmin,tmax:", epochs_bids.tmin, epochs_bids.tmax)
# print("n_times:", len(epochs_bids.times))
# print("duration_s:", epochs_bids.times[-1] - epochs_bids.times[0])

# eeg_bids = TimeSeries.from_mne_epochs(epochs_bids, evs_bids)

In [95]:
# raw = read_raw_bids(
#     bids_path,
#     verbose=True,
# )

# eeg_bids = xr.DataArray(
#     raw.get_data()[None, :, :],                           # -> (1, n_channels, n_times)
#     dims=("event", "channel", "time"),         # match eeg_cml dim names
#     coords={
#         "event": [0],                          # singleton event index
#         "channel": raw.ch_names,
#         "time": raw.times  ,
#         "samplerate": raw.info["sfreq"],                    # scalar coord (optional)
#     },
#     name="eeg",
# )

In [96]:

### ACROSS MULTIPLE SUBJECTS AND SESSIONS
bids_root = "/home1/maint/LTP_BIDS/"
subjects = get_entity_vals(bids_root, "subject")


# subject
def process_raw_signals(sub, exp, sess, bids_root, out_path): # entire signal, not epoched
    ### load cml
    reader = cml.CMLReader(subject=sub, experiment=exp, session=sess)
    eeg_cml = reader.load_eeg().to_ptsa()

    ### load bdf
    # BIDS
    bids_path = BIDSPath(
        subject=sub,
        session=str(sess),
        task=exp.lower(),
        datatype="eeg",
        root=bids_root,
    )

    raw = read_raw_bids(
        bids_path,
        verbose=True,
    )

    eeg_bids = xr.DataArray(
        raw.get_data()[None, :, :],                           # -> (1, n_channels, n_times)
        dims=("event", "channel", "time"),         # match eeg_cml dim names
        coords={
            "event": [0],                          # singleton event index
            "channel": raw.ch_names,
            "time": raw.times * 1000,
            "samplerate": raw.info["sfreq"],                    # scalar coord (optional)
        },
        name="eeg",
    )
    
    # return eeg_cml, eeg_bids
    

    ## load pyedf
    # cml_bdf_path  = f"/protocols/ltp/subjects/{sub}/experiments/{exp}/sessions/{sess}/ephys/current_processed/{sub}_session_{sess}.bdf"
    # eeg_pyedflib = load_bdf_as_xarray(cml_bdf_path)

    # compare
    results = compare_eeg_sources(
        eeg_dict={"BIDS": eeg_bids, "CMLReader": eeg_cml},
        subject=sub,
        experiment=exp,
        session=sess,
        options=["strip_metadata", "compare_raw_signals", "compare_time_coords"]
    )
    
    results["df_raw"].to_csv(f"{out_path}df_raw_{sub}_{exp}_{sess}.csv", index=False)
    results["df_raw_summary"].to_csv(f"{out_path}df_raw_summary_{sub}_{exp}_{sess}.csv", index=False)
    results["df_time"].to_csv(f"{out_path}df_time_{sub}_{exp}_{sess}.csv", index=False)
    return results


In [97]:
# eeg_cml, eeg_bids = process_raw_signals("LTP606", "ValueCourier", 0, "/home1/maint/LTP_BIDS/", "")

In [98]:
def _all_exist(paths):
    return all(os.path.exists(p) for p in paths)

def process_events(sub, exp, sess, evs_types,bids_root, out_path, *, skip_if_exists=True):
    os.makedirs(out_path, exist_ok=True)

    out_behavior_summary = os.path.join(out_path, f"df_behavior_summary_{sub}_{exp}_{sess}.csv")
    # Uncomment these if you re-enable writing them:
    # out_behavior_col = os.path.join(out_path, f"df_behavior_column_summary_{sub}_{exp}_{sess}.csv")
    # out_behavior_mismatch = os.path.join(out_path, f"df_behavior_mismatches_{sub}_{exp}_{sess}.csv")
    expected = [out_behavior_summary]

    if skip_if_exists and _all_exist(expected):
        return {"skipped": True, "reason": "outputs_exist", "paths": expected}
    ### load cml
    cmlreader = cml.CMLReader(subject=sub, experiment=exp, session=sess)
    evs_cml = cmlreader.load('events')
    
    evs_types_set = set(evs_types) if evs_types is not None else set(evs_cml["type"].unique())
    # print(evs_types_set)
    if exp == "ValueCourier":
        evs_cml = fix_evs_cml(evs_cml)
    
    filtered_evs_cml = evs_cml[evs_cml["type"].isin(evs_types_set)]
    
    exp_lower = exp
    exp_lower = exp_lower.lower()
    path_bids = BIDSPath(
                subject=sub,
                session=str(sess),
                task=exp_lower,  
                datatype="beh", 
                suffix="beh",
                extension=".tsv",
                root=bids_root
            )
    evs_bids = pd.read_csv(path_bids.fpath, sep="\t")
    # bids_path = BIDSPath(
    #     subject=sub,
    #     session=str(sess),
    #     task=exp_lower,
    #     datatype="eeg",
    #     root=bids_root,
    # )

#     # --------------------------
#     # Load BIDS events.tsv
#     # --------------------------
#     events_path = os.path.join(bids_path.directory, bids_path.basename + "_events.tsv")
#     evs_bids = pd.read_csv(events_path, sep="\t")
    if exp == "ValueCourier":
        evs_bids = fix_evs_bids(evs_bids)

    filtered_evs_bids = evs_bids[evs_bids["trial_type"].isin(evs_types_set)]
    if filtered_evs_bids.empty:
        raise ValueError("Filtered events dataframe has no rows")
    try: 
        results = compare_behavioral(
            evs_cml, "CMLReader",
            evs_bids, "OpenBIDS",
            options=[
                "compare_onset_as_diff",
                "tolerant_numeric",
                "return_col_summary",
                "return_mismatches",
            ],
            drop_cols=[],  # add stuff you explicitly don't want compared
        )
        # results["df_behavior_summary"].to_csv(f"{out_path}df_behavior_summary{sub}_{exp}_{sess}.csv", index=False)
        # results["df_behavior_column_summary"].to_csv(f"{out_path}df_behavior_column_summary_{sub}_{exp}_{sess}.csv", index=False)
        # results["df_behavior_mismatches"].to_csv(f"{out_path}df_behavior_mismatches_{sub}_{exp}_{sess}.csv", index=False)
        os.makedirs(out_path, exist_ok=True)
        results["df_behavior_summary"].to_csv(
            os.path.join(out_path, f"df_behavior_summary_{sub}_{exp}_{sess}.csv"),
            index=False,
        )
        # results["df_behavior_column_summary"].to_csv(
        #     os.path.join(out_path, f"df_behavior_column_summary_{sub}_{exp}_{sess}.csv"),
        #     index=False,
        # )
        # results["df_behavior_mismatches"].to_csv(
        #     os.path.join(out_path, f"df_behavior_mismatches_{sub}_{exp}_{sess}.csv"),
        #     index=False,
        # )
        return results
    except Exception as e:
        print(f"Failed process_events: {e}")

In [99]:
# process_events("LTP606", "ValueCourier", 0, None, "/home1/maint/LTP_BIDS/", "raw_results/")

In [None]:
# from mne_bids import get_entity_vals
# from ReportRawEEG import *
# ### ACROSS MULTIPLE SUBJECTS AND SESSIONS
# bids_root = "/home1/maint/LTP_BIDS/"
# subjects = get_entity_vals(bids_root, "subject")


# # subject
# def process_epoched_signals(sub, exp, sess, evs_types,  tmin, tmax, bids_root, out_path):
#     ### load cml
#     cmlreader = cml.CMLReader(subject=sub, experiment=exp, session=sess)
#     evs_cml = cmlreader.load('events')
    
#     if evs_types is None:
#         evs_types_set = set(evs_cml["type"].unique())
#     else:
#         evs_types_set = set(evs_types)
    
    
#     filtered_evs_cml = evs_cml[evs_cml["type"].isin(evs_types_set)]
    
#     eeg_cml = cmlreader.load_eeg(filtered_evs_cml, rel_start=tmin, rel_stop=tmax).to_ptsa()
    
    
#     exp_lower = exp
#     exp_lower = exp_lower.lower()
#     ### load events
#     # path_bids = BIDSPath(
#     #             subject=sub,
#     #             session=str(sess),
#     #             task=exp_lower,  
#     #             datatype="beh", 
#     #             suffix="beh",
#     #             extension=".tsv",
#     #             root=bids_root
#     #         )
#     # evs_bids = pd.read_csv(path_bids.fpath, sep="\t")

# #     # --------------------------
# #     # Load BIDS events.tsv
# #     # --------------------------
#     bids_path = BIDSPath(
#         subject=sub,
#         session=str(sess),
#         task=exp_lower,
#         datatype="eeg",
#         root=bids_root,
#     )
#     events_path = os.path.join(bids_path.directory, bids_path.basename + "_events.tsv")
#     evs_bids = pd.read_csv(events_path, sep="\t")
#     if exp == "ValueCourier":
#         evs_bids = fix_evs_bids(evs_bids)

#     filtered_evs_bids = evs_bids[evs_bids["trial_type"].isin(evs_types_set)]
#     if filtered_evs_bids.empty:
#         raise ValueError("Filtered events dataframe has no rows")

#     raw_bids = read_raw_bids(bids_path)

#     raw_bids.set_channel_types({
#         "EXG1": "eog", "EXG2": "eog", "EXG3": "eog", "EXG4": "eog",
#         "EXG5": "misc", "EXG6": "misc", "EXG7": "misc", "EXG8": "misc",
#     })

#     # --------------------------
#     # Epoch WORD events from annotations
#     # --------------------------
#     events_bids, event_bids_id = mne.events_from_annotations(raw_bids)
#     filtered_event_bids_id = {k: v for k, v in event_bids_id.items() if (k in evs_types_set)}
#     if not filtered_event_bids_id:
#         raise ValueError("Filtered events id has no ids")
#     try: 
#         epochs_bids = mne.Epochs(
#             raw_bids,
#             events=events_bids,
#             event_id=filtered_event_bids_id,
#             tmin=tmin / 1000,
#             tmax=tmax / 1000,
#             baseline=None,
#             preload=True,
#             event_repeated="drop",
#         )
#         # del raw_bids, event_bids

#         # We only need EEG channels (exclude eog/misc)
#         picks_eeg = mne.pick_types(epochs_bids.info, eeg=True, eog=False, misc=False)
#         epochs_bids = epochs_bids.pick(picks_eeg)

#         eeg_bids = TimeSeries.from_mne_epochs(epochs_bids, filtered_evs_bids)
#         eeg_bids = eeg_bids.assign_coords(time=eeg_bids["time"] * 1000)
#         eeg_bids["time"].attrs["units"] = "ms"

#         print(eeg_bids.time.data)
#         # del epochs_bids
#         # compare
#         results = compare_eeg_sources(
#             eeg_dict={"BIDS": eeg_bids, "CMLReader": eeg_cml},
#             subject=sub,
#             experiment=exp,
#             session=sess,
#             options=["strip_metadata", "compare_raw_signals", "compare_time_coords"]
#         )
#         os.makedirs(out_path, exist_ok=True)
#         results["df_raw"].to_csv(f"{out_path}df_raw_{sub}_{exp}_{sess}.csv", index=False)
#         results["df_raw_summary"].to_csv(f"{out_path}df_raw_summary_{sub}_{exp}_{sess}.csv", index=False)
#         results["df_time"].to_csv(f"{out_path}df_time_{sub}_{exp}_{sess}.csv", index=False)
#         return results
#     except Exception as e:
#         print(f"Failed process_epoched_signals by failing to load EEG data: {e}")


In [101]:
from mne_bids import BIDSPath
import os
import gc
import numpy as np
import pandas as pd
import mne
from ReportRawEEG import *

def _all_exist(paths):
    return all(os.path.exists(p) for p in paths)

def _dedupe_events_by_sample(df: pd.DataFrame, sample_col: str, *, keep="first") -> pd.DataFrame:
    if sample_col not in df.columns:
        raise ValueError(f"Expected column '{sample_col}' in events df. Columns={list(df.columns)[:20]}")
    df2 = df.copy()
    df2[sample_col] = pd.to_numeric(df2[sample_col], errors="coerce")
    df2 = df2.dropna(subset=[sample_col])
    df2 = df2.sort_values(sample_col, kind="mergesort")
    df2 = df2[~df2[sample_col].duplicated(keep=keep)]
    return df2

def _as_list(x):
    if x is None:
        return None
    if isinstance(x, (list, tuple, set, np.ndarray, pd.Index)):
        return list(x)
    return [x]

def process_epoched_signals_by_type(
    sub,
    exp,
    sess,
    evs_types,
    tmin,
    tmax,
    bids_root,
    out_path,
    *,
    skip_if_exists=True,
    keep="first",
    verbose=False,
):
    """
    Run epoch+compare separately for each event type, append results across types,
    save and return the appended DataFrames.
    """
    os.makedirs(out_path, exist_ok=True)

    # aggregated outputs (ONE set per sub/exp/sess)
    out_raw = os.path.join(out_path, f"df_raw_{sub}_{exp}_{sess}.csv")
    out_raw_summary = os.path.join(out_path, f"df_raw_summary_{sub}_{exp}_{sess}.csv")
    out_time = os.path.join(out_path, f"df_time_{sub}_{exp}_{sess}.csv")
    expected = [out_raw, out_raw_summary, out_time]

    if skip_if_exists and _all_exist(expected):
        print("Files exist: skipped")
        return {"skipped": True, "reason": "outputs_exist", "paths": expected}

    # --------------------------
    # CML: load events once
    # --------------------------
    cmlreader = cml.CMLReader(subject=sub, experiment=exp, session=sess)
    evs_cml = cmlreader.load("events")

    # decide which types to run
    if evs_types is None:
        types_to_run = sorted(pd.unique(evs_cml["type"]))
    else:
        types_to_run = sorted(set(_as_list(evs_types)))

    if len(types_to_run) == 0:
        raise ValueError("types_to_run is empty.")

    # --------------------------
    # BIDS: load raw + annotations once
    # --------------------------
    task = exp.lower()
    bids_path = BIDSPath(
        subject=sub,
        session=str(sess),
        task=task,
        datatype="eeg",
        root=bids_root,
    )

    raw_bids = read_raw_bids(bids_path)
    raw_bids.set_channel_types({
        "EXG1": "eog", "EXG2": "eog", "EXG3": "eog", "EXG4": "eog",
        "EXG5": "misc", "EXG6": "misc", "EXG7": "misc", "EXG8": "misc",
    })

    events_all, event_id_all = mne.events_from_annotations(raw_bids)
    sfreq = float(raw_bids.info["sfreq"])

    # collect per-type outputs
    all_raw = []
    all_raw_summary = []
    all_time = []

    # optional bookkeeping
    per_type_status = []

    for etype in types_to_run:
        if verbose:
            print(f"[{sub} | {exp} | {sess}] type={etype}")

        try:
            # --------------------------
            # CML: filter to this type + dedupe by eegoffset, then epoch
            # --------------------------
            evs_cml_t = evs_cml[evs_cml["type"] == etype].copy()
            if evs_cml_t.empty:
                per_type_status.append((etype, "skip", "no_cml_events"))
                continue

            evs_cml_t = _dedupe_events_by_sample(evs_cml_t, "eegoffset", keep=keep)

            eeg_cml = cmlreader.load_eeg(evs_cml_t, rel_start=tmin, rel_stop=tmax).to_ptsa()

            # --------------------------
            # BIDS: filter annotation labels/codes for this type, dedupe by sample, epoch
            # --------------------------
            if etype not in event_id_all:
                per_type_status.append((etype, "skip", "etype_not_in_annotations"))
                # free CML epoch before continue
                del eeg_cml
                gc.collect()
                continue

            filtered_event_id = {etype: event_id_all[etype]}
            code = filtered_event_id[etype]

            events_filt = events_all[events_all[:, 2] == code]
            if len(events_filt) == 0:
                per_type_status.append((etype, "skip", "no_bids_events"))
                del eeg_cml
                gc.collect()
                continue

            # dedupe by sample
            _, first_idx = np.unique(events_filt[:, 0], return_index=True)
            events_filt = events_filt[np.sort(first_idx)]

            epochs_bids = mne.Epochs(
                raw_bids,
                events=events_filt,
                event_id=filtered_event_id,
                tmin=tmin / 1000.0,
                tmax=tmax / 1000.0,
                baseline=None,
                preload=True,
            )

            picks_eeg = mne.pick_types(epochs_bids.info, eeg=True, eog=False, misc=False)
            epochs_bids = epochs_bids.pick(picks_eeg)

            # metadata aligned to events_filt
            meta = pd.DataFrame({
                "sample": events_filt[:, 0].astype(int),
                "trial_type": [etype] * len(events_filt),
            })
            meta["onset"] = meta["sample"] / sfreq

            eeg_bids = TimeSeries.from_mne_epochs(epochs_bids, meta)
            eeg_bids = eeg_bids.assign_coords(time=eeg_bids["time"] * 1000.0)
            eeg_bids["time"].attrs["units"] = "ms"

            # --------------------------
            # Compare
            # --------------------------
            res = compare_eeg_sources(
                eeg_dict={"BIDS": eeg_bids, "CMLReader": eeg_cml},
                subject=sub,
                experiment=exp,
                session=sess,
                options=["strip_metadata", "compare_raw_signals", "compare_time_coords"],
            )

            # append dfs; add event type column so you can stratify later
            if res.get("df_raw") is not None and not res["df_raw"].empty:
                df = res["df_raw"].copy()
                df["event_type"] = etype
                all_raw.append(df)

            if res.get("df_raw_summary") is not None and not res["df_raw_summary"].empty:
                df = res["df_raw_summary"].copy()
                df["event_type"] = etype
                all_raw_summary.append(df)

            if res.get("df_time") is not None and not res["df_time"].empty:
                df = res["df_time"].copy()
                df["event_type"] = etype
                all_time.append(df)

            per_type_status.append((etype, "ok", ""))

        except Exception as e:
            per_type_status.append((etype, "fail", repr(e)))

        finally:
            # free big objects per type
            for name in ("epochs_bids", "eeg_bids", "eeg_cml", "res", "events_filt", "meta"):
                if name in locals():
                    try:
                        del locals()[name]
                    except Exception:
                        pass
            gc.collect()

    # done with BIDS raw
    try:
        raw_bids.close()
    except Exception:
        pass
    del raw_bids
    gc.collect()

    # concatenate and save
    df_raw_all = pd.concat(all_raw, ignore_index=True) if all_raw else pd.DataFrame()
    df_raw_summary_all = pd.concat(all_raw_summary, ignore_index=True) if all_raw_summary else pd.DataFrame()
    df_time_all = pd.concat(all_time, ignore_index=True) if all_time else pd.DataFrame()

    df_raw_all.to_csv(out_raw, index=False)
    df_raw_summary_all.to_csv(out_raw_summary, index=False)
    df_time_all.to_csv(out_time, index=False)

    return {
        "df_raw": df_raw_all,
        "df_raw_summary": df_raw_summary_all,
        "df_time": df_time_all,
        "per_type_status": pd.DataFrame(per_type_status, columns=["event_type", "status", "detail"]),
        "paths": expected,
    }


In [102]:
# bids_root = "/home1/maint/LTP_BIDS/"
# REL_START, REL_STOP = 200, 3000
# BUFFER_MS = 1000
# tmin = (-BUFFER_MS)
# tmax = ((REL_STOP + BUFFER_MS))
# results = process_epoched_signals("LTP606", "ValueCourier", 0, ["WORD"], tmin, tmax, bids_root, "raw_results")

In [103]:
client = da.new_dask_client_slurm(
    job_name="raw_signals",
    memory_per_job="100GB",
    max_n_jobs=20,
    queue="RAM",
    local_directory="~/scratch",
    log_directory=os.path.expanduser("~/log_directory")
)

Unique port for zrentala is 51618
{'dashboard_address': ':51618'}
To view the dashboard, run: 
`ssh -fN zrentala@rhino2.psych.upenn.edu -L 8000:192.168.86.104:38458` in your local computer's terminal (NOT rhino) 
and then navigate to localhost:8000 in your browser


Perhaps you already have a cluster running?
Hosting the HTTP server on port 38458 instead


In [104]:
max_subjects = 10
# experiments = ["ValueCourier", "ltpFR", "ltpFR2", "VFFR"]
experiments = ["ltpFR"]

subjects_to_exclude = {"LTP001", "LTP9992", "LTP9993"}  # <-- your list here

df = cml.get_data_index()

df_exp = df[df["experiment"].isin(experiments)].copy()

# remove excluded subjects up front
df_exp = df_exp[~df_exp["subject"].isin(subjects_to_exclude)].copy()

dfs = []

for exp in experiments:
    df_this = df_exp[df_exp["experiment"] == exp]

    subjects = (
        df_this["subject"]
        .drop_duplicates()
        .sort_values()      # deterministic
        .head(max_subjects)
    )

    df_keep = df_this[df_this["subject"].isin(subjects)].copy()
    dfs.append(df_keep)

df_subset = pd.concat(dfs, ignore_index=True)
df_subset

Unnamed: 0,Recognition,all_events,contacts,experiment,import_type,localization,math_events,montage,original_experiment,original_session,pairs,ps4_events,session,subject,subject_alias,system_version,task_events
0,,protocols/ltp/subjects/LTP063/experiments/ltpF...,,ltpFR,build,0,protocols/ltp/subjects/LTP063/experiments/ltpF...,0,,0,,,0,LTP063,LTP063,,protocols/ltp/subjects/LTP063/experiments/ltpF...
1,,protocols/ltp/subjects/LTP063/experiments/ltpF...,,ltpFR,build,0,protocols/ltp/subjects/LTP063/experiments/ltpF...,0,,1,,,1,LTP063,LTP063,,protocols/ltp/subjects/LTP063/experiments/ltpF...
2,,protocols/ltp/subjects/LTP063/experiments/ltpF...,,ltpFR,build,0,protocols/ltp/subjects/LTP063/experiments/ltpF...,0,,10,,,10,LTP063,LTP063,,protocols/ltp/subjects/LTP063/experiments/ltpF...
3,,protocols/ltp/subjects/LTP063/experiments/ltpF...,,ltpFR,build,0,protocols/ltp/subjects/LTP063/experiments/ltpF...,0,,11,,,11,LTP063,LTP063,,protocols/ltp/subjects/LTP063/experiments/ltpF...
4,,protocols/ltp/subjects/LTP063/experiments/ltpF...,,ltpFR,build,0,protocols/ltp/subjects/LTP063/experiments/ltpF...,0,,12,,,12,LTP063,LTP063,,protocols/ltp/subjects/LTP063/experiments/ltpF...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
180,,protocols/ltp/subjects/LTP074/experiments/ltpF...,,ltpFR,build,0,protocols/ltp/subjects/LTP074/experiments/ltpF...,0,,5,,,5,LTP074,LTP074,,protocols/ltp/subjects/LTP074/experiments/ltpF...
181,,protocols/ltp/subjects/LTP074/experiments/ltpF...,,ltpFR,build,0,protocols/ltp/subjects/LTP074/experiments/ltpF...,0,,6,,,6,LTP074,LTP074,,protocols/ltp/subjects/LTP074/experiments/ltpF...
182,,protocols/ltp/subjects/LTP074/experiments/ltpF...,,ltpFR,build,0,protocols/ltp/subjects/LTP074/experiments/ltpF...,0,,7,,,7,LTP074,LTP074,,protocols/ltp/subjects/LTP074/experiments/ltpF...
183,,protocols/ltp/subjects/LTP074/experiments/ltpF...,,ltpFR,build,0,protocols/ltp/subjects/LTP074/experiments/ltpF...,0,,8,,,8,LTP074,LTP074,,protocols/ltp/subjects/LTP074/experiments/ltpF...


In [105]:
# cmlreader = cml.CMLReader(subject='LTP063', experiment="ltpFR", session=0)
# eeg = cmlreader.load_eeg().to_ptsa()

In [106]:
# from mne_bids import get_entity_vals, BIDSPath

# bids_root = "/home1/maint/LTP_BIDS/"
# out_path = "raw_results/"
# os.makedirs(out_path, exist_ok=True)

# REL_STOP = 3000
# BUFFER_MS = 1000
# tmin = -BUFFER_MS
# tmax = REL_STOP + BUFFER_MS
# evs_type = None

# subjects = get_entity_vals(bids_root, "subject")
# print(subjects)
# # tasks = get_entity_vals(bids_root, "task")         # <-- experiment == task in BIDS
# tasks = ["ltpFR2", "ltpFR", "ValueCourier", "VFFR"]
# sessions_all = get_entity_vals(bids_root, "session")

# futures = []

# for sub in subjects:
#     # only sessions that exist for this subject (directory-based)
#     ses_for_sub = [
#         ses for ses in sessions_all
#         if os.path.isdir(os.path.join(bids_root, f"sub-{sub}", f"ses-{ses}"))
#     ]

#     # only tasks that exist for this subject+session (file-based)
#     for ses in ses_for_sub:
#         # print(ses)
#         for task in tasks:
#             bids_path = BIDSPath(
#                 subject=sub, session=str(ses), task=task,
#                 datatype="eeg", root=bids_root
#             )

#             # skip nonexistent combinations (no EEG file)
#             matches = bids_path.match()
#             if len(matches) == 0:
#                 continue
#             print(f"{sub} {task} {sess} {bids_root} {out_path}")
#             futures.append(client.submit(
#                 process_epoched_signals, sub, task, ses, evs_type, tmin, tmax, bids_root, out_path
#             ))
#             futures.append(client.submit(
#                 process_events, sub, task, ses, evs_type, bids_root, out_path
#             ))


In [107]:
# get futures
bids_root = "/home1/maint/LTP_BIDS/"
subjects = get_entity_vals(bids_root, "subject")
out_path = "raw_results_type/"
futures = []
REL_START, REL_STOP = 200, 1000
BUFFER_MS = 1000
# evs_type = ["WORD"]
evs_type = None
tmin = (-BUFFER_MS)
tmax = ((REL_STOP + BUFFER_MS))
future_meta = {} 
futures_eeg = []
for i, row in df_subset.iterrows():
    sub = row["subject"]
    exp = row["experiment"]
    sess = row["session"]
    # try:
    #     process_epoched_signals_by_type(sub, exp, sess, evs_type, tmin, tmax, bids_root, out_path, verbose=True)
    # except Exception as e:
    #     print(e)
    fut = client.submit(
        process_epoched_signals_by_type,
        sub, exp, sess, evs_type, tmin, tmax, bids_root, out_path
    )

    futures_eeg.append(fut)
    future_meta[fut.key] = (sub, exp, sess)
    # if i < 15:
    #     break

# for sub in subjects:
#     subject_root = os.path.join(bids_root, f"sub-{sub}")
#     experiments = get_entity_vals(subject_root, "experiment")
#     print(experiments)
#     sessions = get_entity_vals(subject_root, "session")
#     # futures.extend([client.submit(process_raw_signals, sub, "ValueCourier", sess, bids_root,out_path) for sess in sessions])
#     futures.extend([client.submit(process_epoched_signals, sub, exp, sess, evs_type, tmin, tmax, bids_root, out_path)for sess in sessions for exp in experiments])
#     futures.extend([client.submit(process_events, sub, exp, sess, evs_type, bids_root, out_path) for sess in sessions])
#     break
#     # process_epoched_signals(sub, exp, sess, evs_types, tmin, tmax, bids_root, out_path)

In [108]:
# run futures 
from dask.distributed import as_completed

all_df_raw = []
all_df_raw_summary = []
all_df_time = []
all_df_behavior_summary = []
# all_df_behavior_column_summary = []
# all_df_behavior_mismatches = []
# all_eegs_std = []

n_done, n_fail = 0, 0

for fut in as_completed(futures_eeg):
    sub, exp, sess = future_meta.get(fut.key, ("<unknown>", "<unknown>", "<unknown>"))

    try:
        out = fut.result()

        if out.get("skipped"):
            print(f"[SKIP] {sub} | {exp} | {sess}")
            continue

        if out["df_raw"] is not None and not out["df_raw"].empty:
            all_df_raw.append(out["df_raw"])

        if out["df_raw_summary"] is not None and not out["df_raw_summary"].empty:
            all_df_raw_summary.append(out["df_raw_summary"])

        if out["df_time"] is not None and not out["df_time"].empty:
            all_df_time.append(out["df_time"])

        n_done += 1
        print(f"[DONE] {sub} | {exp} | {sess}  ({n_done})")

    except Exception as e:
        n_fail += 1
        print(f"[FAIL] {sub} | {exp} | {sess}  -> {e}  ({n_fail})")
df_raw_all = pd.concat(all_df_raw, ignore_index=True)
df_raw_summary_all = pd.concat(all_df_raw_summary, ignore_index=True)
df_time_all = pd.concat(all_df_time, ignore_index=True)
# df_behavior_summary_all = pd.concat(all_df_behavior_summary, ignore_index=True)

# df_raw_all.to_csv(f"{out_path}df_raw_all.csv", index=False)
# df_raw_summary_all.to_csv(f"{out_path}df_raw_summary_all.csv", index=False)
# df_time_all.to_csv(f"{out_path}df_time_all.csv", index=False)
# df_behavior_summary_all.to_csv(f"{out_path}df_behavior_summary_all.csv", index=False)

[FAIL] LTP063 | ltpFR | 1  -> [Errno 2] No such file or directory: '/home1/maint/LTP_BIDS/sub-LTP063/ses-1/eeg'  (1)
[FAIL] LTP063 | ltpFR | 0  -> [Errno 2] No such file or directory: '/home1/maint/LTP_BIDS/sub-LTP063/ses-0/eeg'  (2)
[FAIL] LTP063 | ltpFR | 18  -> [Errno 2] No such file or directory: '/home1/maint/LTP_BIDS/sub-LTP063/ses-18/eeg'  (3)
[FAIL] LTP063 | ltpFR | 19  -> [Errno 2] No such file or directory: '/home1/maint/LTP_BIDS/sub-LTP063/ses-19/eeg'  (4)
[FAIL] LTP063 | ltpFR | 11  -> [Errno 2] No such file or directory: '/home1/maint/LTP_BIDS/sub-LTP063/ses-11/eeg'  (5)
[FAIL] LTP063 | ltpFR | 2  -> [Errno 2] No such file or directory: '/home1/maint/LTP_BIDS/sub-LTP063/ses-2/eeg'  (6)
[FAIL] LTP063 | ltpFR | 10  -> [Errno 2] No such file or directory: '/home1/maint/LTP_BIDS/sub-LTP063/ses-10/eeg'  (7)
[FAIL] LTP063 | ltpFR | 3  -> [Errno 2] No such file or directory: '/home1/maint/LTP_BIDS/sub-LTP063/ses-3/eeg'  (8)
[FAIL] LTP063 | ltpFR | 13  -> [Errno 2] No such file or

ValueError: No objects to concatenate

In [None]:
futures_beh = []
for i, row in df_subset.iterrows():
    sub = row["subject"]
    exp = row["experiment"]
    sess = row["session"]
    futures_beh.append(client.submit(process_events, sub, exp, sess, evs_type, bids_root, out_path))

In [None]:
# run futures 
from dask.distributed import as_completed

all_df_behavior_summary = []

n_done, n_fail = 0, 0

for fut in as_completed(futures_beh):
    try:
        out = fut.result()
        
        if out["df_behavior_summary"] is not None and not out["df_behavior_summary"].empty:
            all_df_behavior_summary.append(out["df_behavior_summary"])

        # Append eeg_std list
        # all_eegs_std.extend(out["eegs_std"])

        n_done += 1
        print(f"[DONE] {n_done}")

    except Exception as e:
        n_fail += 1
        print(f"[ERR] Future failed ({n_fail} fails): {e}")
df_behavior_summary_all = pd.concat(all_df_behavior_summary, ignore_index=True)

df_behavior_summary_all.to_csv(f"{out_path}df_behavior_summary_all.csv", index=False)

In [None]:
import os
import pandas as pd
out_path = "raw_results/"
df_raw_filenames = []
df_raw_summary_filenames = []
df_time_filenames = []
df_behavior_summary_filenames = []

for dirpath, _, filenames in os.walk(out_path):
    for f in filenames:
        full_path = os.path.join(dirpath, f)
        
        # Categorize based on string patterns
        if f.startswith('df_raw_summary_') and f.endswith('.csv'):
            df_raw_summary_filenames.append(full_path)
        elif f.startswith('df_raw_') and f.endswith('.csv'):
            df_raw_filenames.append(full_path)
        elif f.startswith('df_time_') and f.endswith('.csv'):
            df_time_filenames.append(full_path)
        elif f.startswith('df_behavior_summary_') and f.endswith('.csv'):
            df_behavior_summary_filenames.append(full_path)
            
def load_and_concat(file_list):
    if not file_list:
        return pd.DataFrame()  # Return empty DF if no files found
    # Read each CSV and combine them into one
    return pd.concat([pd.read_csv(f) for f in file_list], ignore_index=True)

# Create the 3 distinct DataFrames
df_raw_all = load_and_concat(df_raw_filenames)
df_raw_summary_all = load_and_concat(df_raw_summary_filenames)
df_time_all = load_and_concat(df_time_filenames)
df_behavior_summary_all = load_and_concat(df_behavior_summary_filenames)


# df_raw_all.to_csv("df_raw_all.csv", index=False)
# df_raw_summary_all.to_csv("df_raw_summary_all.csv", index=False)
# df_time_all.to_csv("df_time_all.csv", index=False)

In [None]:
# plot mean and std difference
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

comparisons = df_time_all['comparison'].unique()
subjects = df_time_all['subject'].unique()

fig, axes = plt.subplots(1, len(comparisons), figsize=(6 * len(comparisons), 5), sharex=True)
if len(comparisons) == 1: axes = [axes]

for i, comp in enumerate(comparisons):
    ax = axes[i]
    comp_df = df_time_all[df_time_all['comparison'] == comp]
    
    for subj in subjects:
        subj_df = comp_df[comp_df['subject'] == subj].sort_values('session')
        if subj_df.empty: continue
        
        # Plot mean line
        line, = ax.plot(subj_df['session'], subj_df['mean_abs_time_diff'], marker='o', label=subj)
        
        # Add shaded Std region
        ax.fill_between(
            subj_df['session'], 
            subj_df['mean_abs_time_diff'] - subj_df['std_time_diff'],
            subj_df['mean_abs_time_diff'] + subj_df['std_time_diff'],
            color=line.get_color(), 
            alpha=0.15
        )
    
    ax.set_title(comp)
    ax.set_xlabel('Session')
    if i == 0: ax.set_ylabel('Mean Abs Time Diff ($\pm$ Std)')
    ax.legend(title='Subject')

plt.tight_layout()
plt.show()

In [None]:
# plot mse
fig, axes = plt.subplots(1, len(comparisons), figsize=(6 * len(comparisons), 5), sharex=True)
if len(comparisons) == 1: axes = [axes]

for i, comp in enumerate(comparisons):
    ax = axes[i]
    comp_df = df_time_all[df_time_all['comparison'] == comp]
    
    for subj in subjects:
        subj_df = comp_df[comp_df['subject'] == subj].sort_values('session')
        if subj_df.empty: continue
        
        # Plot mean line
        line, = ax.plot(subj_df['session'], subj_df['mse_time'], marker='o', label=subj)
        
#         # Add shaded Std region
#         ax.fill_between(
#             subj_df['session'], 
#             subj_df['mean_abs_time_diff'] - subj_df['std_time_diff'],
#             subj_df['mean_abs_time_diff'] + subj_df['std_time_diff'],
#             color=line.get_color(), 
#             alpha=0.15
#         )
    
    ax.set_title(comp)
    ax.set_xlabel('Session')
    if i == 0: ax.set_ylabel('MSE Time ($\pm$ Std)')
    ax.legend(title='Subject')

plt.tight_layout()
plt.show()

In [None]:
comparisons = df_raw_summary_all['comparison'].unique()
subjects = df_raw_summary_all['subject'].unique()

fig, axes = plt.subplots(1, len(comparisons), figsize=(6 * len(comparisons), 5), sharex=True)
if len(comparisons) == 1: axes = [axes]

for i, comp in enumerate(comparisons):
    ax = axes[i]
    comp_df = df_raw_summary_all[df_raw_summary_all['comparison'] == comp]
    
    for subj in subjects:
        subj_df = comp_df[comp_df['subject'] == subj].sort_values('session')
        if subj_df.empty: continue
        
        # Plot mean line
        line, = ax.plot(subj_df['session'], subj_df['mean_abs_diff'], marker='o', label=subj)
        
        # Add shaded Std region
        ax.fill_between(
            subj_df['session'], 
            subj_df['mean_abs_diff'] - subj_df['std_diff'],
            subj_df['mean_abs_diff'] + subj_df['std_diff'],
            color=line.get_color(), 
            alpha=0.15
        )
    
    ax.set_title(comp)
    ax.set_xlabel('Session')
    if i == 0: ax.set_ylabel('Mean Abs Signal Diff ($\pm$ Std)')
    ax.legend(title='Subject')

plt.tight_layout()
plt.show()

In [None]:
# plot mse
fig, axes = plt.subplots(1, len(comparisons), figsize=(6 * len(comparisons), 5), sharex=True)
if len(comparisons) == 1: axes = [axes]

for i, comp in enumerate(comparisons):
    ax = axes[i]
    comp_df = df_raw_summary_all[df_raw_summary_all['comparison'] == comp]
    
    for subj in subjects:
        subj_df = comp_df[comp_df['subject'] == subj].sort_values('session')
        if subj_df.empty: continue
        
        # Plot mean line
        line, = ax.plot(subj_df['session'], subj_df['mse'], marker='o', label=subj)
        
#         # Add shaded Std region
#         ax.fill_between(
#             subj_df['session'], 
#             subj_df['mean_abs_time_diff'] - subj_df['std_time_diff'],
#             subj_df['mean_abs_time_diff'] + subj_df['std_time_diff'],
#             color=line.get_color(), 
#             alpha=0.15
#         )
    
    ax.set_title(comp)
    ax.set_xlabel('Session')
    if i == 0: ax.set_ylabel('MSE Raw Signal ($\pm$ Std)')
    ax.legend(title='Subject')

plt.tight_layout()
plt.show()

In [None]:
# plot n_channels diff
fig, axes = plt.subplots(1, len(comparisons), figsize=(6 * len(comparisons), 5), sharex=True)
if len(comparisons) == 1: axes = [axes]

for i, comp in enumerate(comparisons):
    ax = axes[i]
    comp_df = df_raw_summary_all[df_raw_summary_all['comparison'] == comp]
    
    for subj in subjects:
        subj_df = comp_df[comp_df['subject'] == subj].sort_values('session')
        if subj_df.empty: continue
        
        # Plot mean line
        line, = ax.plot(subj_df['session'], subj_df['n_exact_diff_channels'], marker='o', label=f"{subj} exact")
        line, = ax.plot(subj_df['session'], subj_df['n_close_diff_channels'], marker='o', label=f"{subj} close")
        
#         # Add shaded Std region
#         ax.fill_between(
#             subj_df['session'], 
#             subj_df['mean_abs_time_diff'] - subj_df['std_time_diff'],
#             subj_df['mean_abs_time_diff'] + subj_df['std_time_diff'],
#             color=line.get_color(), 
#             alpha=0.15
#         )
    
    ax.set_title(comp)
    ax.set_xlabel('Session')
    if i == 0: ax.set_ylabel('n_channels different ($\pm$ Std)')
    ax.legend(title='Subject')

plt.tight_layout()
plt.show()