In [None]:
import numpy as np
import cedalion
import cedalion.nirs
import cedalion.datasets
import cedalion.plots
import xarray as xr
from cedalion import units
import cedalion.models.glm as glm
from sklearn.model_selection import StratifiedKFold
import pickle
import os
import pandas as pd
import matplotlib.pyplot as plt
import cedalion.sigproc.motion_correct as motion_correct
from cedalion.sigproc.quality import repair_amp
import cedalion.models.glm.design_matrix as glm_dm
import configs as data_configs

path_prefix = data_configs.data_path_prefix  # prefix prepended to dataset-relative file paths below
xr.set_options(display_expand_data=False);  # compact xarray reprs in notebook output

In [None]:
def create_cv_splits(stim_df, n_splits):
    """Partition stimulus rows into stratified K-fold train/test sets.

    Stratification is on the ``trial_type`` column, so every fold keeps the
    class proportions of ``stim_df``. No shuffling is applied, so the folds are
    deterministic.

    Returns a list of ``n_splits`` dicts with keys ``'train_stim'`` and
    ``'test_stim'``, each a row subset of ``stim_df`` (original index kept).
    """
    labels = stim_df['trial_type']
    # StratifiedKFold only needs the number of samples from X; the labels drive the split.
    fold_iter = StratifiedKFold(n_splits=n_splits).split(np.zeros(len(labels)), labels)
    return [
        {'train_stim': stim_df.iloc[train_rows], 'test_stim': stim_df.iloc[test_rows]}
        for train_rows, test_rows in fold_iter
    ]

In [None]:
def fit_glm_excluding_test(ts, test_stim, dm, mask_before_secs, mask_after_secs, method='individual', noise_model='ols'):
    """Fit a GLM with the design-matrix rows around held-out stimuli set to zero.

    A zeroed design-matrix row carries no regressor information, so (for OLS) it
    does not contribute to the estimated coefficients. This keeps the test
    trials of a cross-validation split out of the GLM fit.

    Parameters
    ----------
    ts : xr.DataArray
        Time series forwarded to ``glm.fit``.
    test_stim : pd.DataFrame
        Held-out stimuli with ``onset`` and ``duration`` columns, on the same
        time base as the design matrix's ``time`` coordinate.
    dm : DesignMatrix
        Full design matrix. Its ``common`` part and every ``channel_wise``
        regressor are masked; ``dm`` itself is left unchanged.
    mask_before_secs, mask_after_secs : float
        Margin masked before each onset and after each stimulus end.
    method : {'individual', 'block'}
        'individual' masks ``[onset - before, onset + duration + after]`` for
        every test stimulus. 'block' masks one contiguous span from the first
        test onset to the latest test-stimulus end. Training stimuli that fall
        inside that span are masked as well.
    noise_model : str
        Forwarded to ``glm.fit``.

    Returns
    -------
    tuple
        ``(params, updated_dm)``: the fitted ``sm.params`` and the masked
        design matrix.

    Raises
    ------
    ValueError
        If ``method`` is not 'individual' or 'block'. Silently skipping the
        masking would leak test data into the fit.
    """
    if method not in ('individual', 'block'):
        raise ValueError(f"Unknown method {method!r}; expected 'individual' or 'block'")

    onsets = test_stim['onset']
    ends = onsets + test_stim['duration']

    # (start, end) time windows to hide from the fit.
    if test_stim.empty:
        windows = []
    elif method == 'block':
        # Use the latest stimulus *end*, not the end of the latest-onset stimulus,
        # so that a longer earlier stimulus cannot stick out of the masked span.
        windows = [(onsets.min() - mask_before_secs, ends.max() + mask_after_secs)]
    else:
        windows = list(zip(onsets - mask_before_secs, ends + mask_after_secs))

    time = dm.common['time']
    mask = xr.zeros_like(time, dtype=bool)
    for start, end in windows:
        mask = mask | ((time >= start) & (time <= end))
    keep = ~mask

    masked_common = dm.common.where(keep, other=0)

    # Mask every channel-wise regressor, not only the first one, so none of
    # them is dropped from the refit.
    masked_channel_wise = []
    for reg in (dm.channel_wise or []):
        unit = reg.pint.units
        # Quantified regressors need a fill value with matching units; plain ones take 0.
        fill = 0 if unit is None else 0 * unit
        masked_channel_wise.append(reg.where(keep, other=fill))

    updated_dm = glm_dm.DesignMatrix(
        masked_common,
        masked_channel_wise,
    )

    reg_results = glm.fit(ts, updated_dm, noise_model=noise_model)
    return reg_results.sm.params, updated_dm

In [None]:
def _remove_nuisance(ts_data, betas, dm):
    """Return ``ts_data`` minus the GLM prediction of its non-HRF regressors.

    Regressors whose name starts with "HRF " are left out of the prediction,
    so only the remaining (nuisance) components are subtracted.
    """
    nuisance_betas = betas.sel(regressor=~betas.regressor.str.startswith("HRF "))
    pred_wo_hrf = glm.predict(ts_data, nuisance_betas, dm)
    # Give the prediction the data's units before subtracting.
    # NOTE(review): assumes glm.predict returns unitless values here — confirm.
    if ts_data.pint.units == cedalion.units.micromolar:
        pred_wo_hrf = pred_wo_hrf.pint.quantify("micromolar")
    return ts_data - pred_wo_hrf


def _baselined_epochs(ts_data, stim_df, trial_types, before_secs, epo_after_secs):
    """Cut epochs around each stimulus and subtract the mean of the pre-onset samples."""
    epochs = ts_data.cd.to_epochs(
        stim_df,
        trial_types,
        before=before_secs * units.seconds,
        after=epo_after_secs * units.seconds,
    )
    baseline = epochs.sel(reltime=(epochs.reltime < 0)).mean("reltime")
    return epochs - baseline


def _hrf_difference(betas, dm, trial_types):
    """HbO difference between HRF betas of two trial types, or None if not computable.

    Returns None when the design matrix has no 'chromo' dimension or too few HRF
    regressors exist. With more than two HRF regressors, returns the squared L2
    distance between the betas of ``trial_types[0]`` and ``trial_types[1]``
    (summed over their basis regressors). Otherwise returns the plain
    difference of the two HRF regressors' betas.
    """
    if 'chromo' not in dm.common.dims:
        return None

    regs = dm.common.regressor
    hrf_regs = regs.sel(regressor=regs.regressor.str.startswith("HRF")).values

    if hrf_regs.size > 2:
        if len(trial_types) < 2:
            return None
        hrf1 = betas.sel(regressor=betas.regressor.str.startswith("HRF " + trial_types[0]), chromo='HbO')
        hrf2 = betas.sel(regressor=betas.regressor.str.startswith("HRF " + trial_types[1]), chromo='HbO')
        # The two selections carry different regressor labels. Plain subtraction
        # would align them on 'regressor' (inner join -> empty result, sum == 0),
        # so pair the basis functions by position instead.
        hrf2 = hrf2.assign_coords(regressor=hrf1.regressor.values)
        return ((hrf1 - hrf2) ** 2).sum(dim='regressor')

    if hrf_regs.size < 2:
        return None
    return betas.sel(regressor=hrf_regs[0], chromo='HbO') - betas.sel(regressor=hrf_regs[1], chromo='HbO')


def prepare_epochs_labels(ts, stim_df, cv_splits, trial_types, before_secs, max_stim_dur, after_secs, dm=None, exclude_test=True, glm_test_method='individual', noise_model='ols'):
    """Extract baseline-corrected epochs and per-split label metadata.

    If ``dm`` is given, the GLM prediction of all non-HRF regressors is
    subtracted from ``ts`` before epoching. With ``exclude_test`` true, the GLM
    is refit for every split with that split's test-stimulus windows masked
    (see ``fit_glm_excluding_test``), so each split gets its own epochs.
    Otherwise a single set of epochs is shared by all splits.

    Parameters
    ----------
    ts : xr.DataArray
        Time series to epoch.
    stim_df : pd.DataFrame
        Stimuli used for epoching (all trials, train and test).
    cv_splits : list of dict
        Items with ``'train_stim'`` / ``'test_stim'`` DataFrames, e.g. from
        ``create_cv_splits``.
    trial_types : list
        Trial types passed to ``to_epochs``.
    before_secs : float
        Epoch start before onset. The samples before onset form the baseline.
        Also used as the pre-onset test-masking margin.
    max_stim_dur, after_secs : float
        Epochs end ``max_stim_dur + after_secs`` after onset. ``after_secs`` is
        also the post-stimulus test-masking margin.
    dm : DesignMatrix, optional
        GLM design matrix used for nuisance removal.
    exclude_test : bool
        Refit the GLM per split without the test windows (only if ``dm`` is given).
    glm_test_method, noise_model : str
        Forwarded to ``fit_glm_excluding_test`` / ``glm.fit``.

    Returns
    -------
    dict
        ``'epochs'``: one epochs array shared by all splits, or a list with one
        entry per split. ``'splits'``: list of dicts with ``train_indices``,
        ``test_indices``, ``y`` (epoch trial types) and ``hrf_diff`` (None unless
        computed from a per-split GLM).
    """
    split_metadata = []
    epochs_list = []
    epo_after_secs = max_stim_dur + after_secs

    # Without a design matrix, or when the GLM may see the test trials, the
    # cleaned data (and therefore the epochs) is identical for every split.
    use_shared_epochs = dm is None or not exclude_test

    print("Using shared epochs:", use_shared_epochs)
    print("dm:", dm)

    if use_shared_epochs:
        ts_data = ts.copy()
        if dm is not None:
            betas = glm.fit(ts_data, dm, noise_model=noise_model).sm.params
            ts_data = _remove_nuisance(ts_data, betas, dm)
        shared_epochs = _baselined_epochs(ts_data, stim_df, trial_types, before_secs, epo_after_secs)

    for split in cv_splits:
        train_stim = split['train_stim']
        test_stim = split['test_stim']
        diff_hrf = None

        if use_shared_epochs:
            epochs = shared_epochs
        else:
            ts_data = ts.copy()
            # Betas are estimated with the test windows masked out of the design matrix ...
            betas, _ = fit_glm_excluding_test(ts_data, test_stim, dm, before_secs, after_secs, method=glm_test_method, noise_model=noise_model)
            # ... but the nuisance is predicted with the unmasked design matrix, so it is
            # removed over the whole recording, including the test windows.
            ts_data = _remove_nuisance(ts_data, betas, dm)
            epochs = _baselined_epochs(ts_data, stim_df, trial_types, before_secs, epo_after_secs)
            diff_hrf = _hrf_difference(betas, dm, trial_types)
            epochs_list.append(epochs)

        split_metadata.append({
            'train_indices': train_stim.index.values,
            'test_indices': test_stim.index.values,
            'y': epochs['trial_type'].values,
            'hrf_diff': diff_hrf,
        })

    return {
        'epochs': shared_epochs if use_shared_epochs else epochs_list,
        'splits': split_metadata,
    }

In [None]:
# --- Run configuration -------------------------------------------------------
stim_vs_rest = True  # currently unused (stim-vs-rest relabelling is disabled)
save = True          # write the epochs/labels pickle for every run
test = True          # forwarded to data_configs.load_dataset_configs(test=...)
before_secs = 2      # epoch start (s before onset); pre-onset samples form the baseline
after_secs = 8       # epochs end max_stim_dur + after_secs after onset
dataset_configs = data_configs.load_dataset_configs(["HD_Squeezing", "Syn_Finger_Tapping", "BS_Laura"], load_sensitivity=False, test=test)
dist_thresh = 2.6 # cm threshold for short vs long channels
n_cv_splits = 5      # folds of the stratified cross-validation

print("SAVE = ", save)
bs_laura_only_roi = True
if bs_laura_only_roi:
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(path_prefix + 'BS_Laura/BS_Laura_YY_parcel_sens_channels', 'rb') as f:
        channel_roi_bs_laura = pickle.load(f)
    print("len channel_roi_laura: ", len(channel_roi_bs_laura))


def _select_channels(da, keep):
    """Return ``da`` restricted to channels contained in ``keep``, in ``da``'s own channel order."""
    return da.sel(channel=[c for c in da.channel.values if c in keep])


for data_type in ['Syn_Finger_Tapping']:

    synthetic = data_type.startswith("Syn")
    spatial_scales = [1, 2, 3] if synthetic else [1]
    glm_test_m = 'block' if synthetic else 'individual'
    intensity_keys = ["", "_02", "_03"] if synthetic else [""]

    config = dataset_configs.get(data_type)
    if config is None:
        raise ValueError(f"Unknown data_type: {data_type}")

    base_path = config.base_path
    subjects = config.subjects

    # Channel subsets do not depend on the spatial scaling; load them once per dataset.
    with open(base_path + config.subsets_path, 'rb') as file:
        subsets_data = pickle.load(file)
    subset_keys = list(subsets_data.keys())

    print("SUBSET KEYS:")
    print(subset_keys)

    for spatial_scaling in spatial_scales:

        recs = {}
        runs = {}

        print("SUBJECTS:", subjects)
        for subject in subjects:

            print("")
            print(f"Subject {subject}")
            print("")

            run_list = config.runs(subject)
            print(subject)
            recs[subject] = []

            for run_idx, run in enumerate(run_list):

                print("")
                print(f"Run {run_idx}")
                print("")

                path = config.snirf_path.format(subject=subject, run=run, spatial_scale=spatial_scaling)

                # Clean-channel lists are addressed by run position, not by run label.
                clean_chs_full_path = base_path + config.clean_channels_path(subject=subject, run=run_idx)
                with open(clean_chs_full_path, 'rb') as f:
                    clean_channels = pickle.load(f)

                print(len(clean_channels))

                rec = cedalion.io.read_snirf(base_path + path)[0]

                print("PATH", base_path + path)

                dpf = xr.DataArray(
                    [6, 6],
                    dims="wavelength",
                    coords={"wavelength": rec["amp"].wavelength},
                )

                # Stimulus processing
                stim_df = rec.stim.copy()

                if data_type == 'BS_Laura':
                    stim_path = config.stim_template.format(subject=subject, run=run)
                    stim_df = pd.read_csv(base_path + stim_path, sep='\t')

                if config.preprocess_stim is not None:
                    stim_df = config.preprocess_stim(stim_df)
                    if not stim_df.empty and "onset" in stim_df.columns:
                        stim_df = stim_df.reset_index(drop=True)

                # These depend only on the stimuli, not on the intensity scaling.
                # sorted() keeps the trial-type order reproducible; set iteration
                # order of strings changes between interpreter sessions.
                splits = create_cv_splits(stim_df, n_cv_splits)
                trial_types = sorted(set(stim_df.trial_type.values))
                max_stim_dur = round(stim_df.duration.max())
                print(stim_df)
                print("max stim dur: ", max_stim_dur)

                for int_key in intensity_keys:

                    print(f"Processing intensity: '{int_key or 'base'}'")

                    if synthetic:
                        od = rec["od" + int_key]
                    elif data_type == 'BS_Laura':
                        rec["amp"] = repair_amp(rec["amp"], median_len=0)
                        od = cedalion.nirs.int2od(rec['amp'])
                    else:
                        rec["amp"] = repair_amp(rec["amp"])
                        od = cedalion.nirs.int2od(rec['amp'])
                    if bs_laura_only_roi and data_type == 'BS_Laura':
                        print("OD BEFORE")
                        print(od.shape)
                        od = od.sel(channel=channel_roi_bs_laura)
                        print("OD AFTER")
                        print(od.shape)
                    od = motion_correct.tddr(od)
                    od = motion_correct.motion_correct_wavelet(od)

                    # Band-passed data is the 'default'.
                    od_bp = od.cd.freq_filter(fmin=0.02, fmax=0.5, butter_order=4)
                    conc_bp = cedalion.nirs.od2conc(od_bp, rec.geo3d, dpf, spectrum="prahl")

                    epochs_labels_data = {}

                    print("")
                    print("STIM DF")
                    print(stim_df)
                    print("")

                    # Extract sparse subsets
                    for sub_key in subset_keys:

                        subset_channels = subsets_data[sub_key]['all']

                        print("Subset size: ", len(subset_channels))
                        print("OD Size", od_bp.channel.values.size)

                        # Select every subset from the full band-passed data so that
                        # subsets are not intersected with one another.
                        od_sub = _select_channels(od_bp, subset_channels)
                        conc_sub = _select_channels(conc_bp, subset_channels)

                        ts_long, ts_short = cedalion.nirs.split_long_short_channels(
                            conc_sub, rec.geo3d, distance_threshold=dist_thresh * units.cm
                        )
                        _, ts_short_od = cedalion.nirs.split_long_short_channels(
                            od_sub, rec.geo3d, distance_threshold=dist_thresh * units.cm
                        )

                        if sub_key == 'subset_3':
                            # special case for sparsest subset since we only have very few or no channels:
                            # just take mean signal from all channels as proxy for physiological noise
                            ss_reg = _select_channels(conc_sub, clean_channels)
                            ss_reg_od = _select_channels(od_sub, clean_channels)
                        else:
                            ss_reg = _select_channels(ts_short, clean_channels)
                            ss_reg_od = _select_channels(ts_short_od, clean_channels)

                        sigma_val = 3
                        T_val = 2

                        # Unregressed data is only stored once (for the 'full' subset); sparse
                        # subset channels can be selected from it later, which saves memory.
                        if sub_key == 'full':
                            epochs_labels_data['all'] = prepare_epochs_labels(conc_sub, stim_df, splits, trial_types, before_secs, max_stim_dur, after_secs)
                            if not synthetic:
                                epochs_labels_data['all_od'] = prepare_epochs_labels(od_sub, stim_df, splits, trial_types, before_secs, max_stim_dur, after_secs)

                        dm_mean = (
                            glm.design_matrix.hrf_regressors(
                                ts_long, stim_df, glm.Gamma(tau=0 * units.s, sigma=sigma_val * units.s, T=T_val * units.s)
                            )
                            & glm.design_matrix.drift_regressors(ts_long, drift_order=1)
                            & glm.design_matrix.average_short_channel_regressor(ss_reg)
                        )

                        # The short-channel-regressed data differs for each subset.
                        epochs_labels_data[sub_key] = {}
                        epochs_labels_data[sub_key]['long_ss_mean'] = prepare_epochs_labels(ts_long, stim_df, splits, trial_types, before_secs, max_stim_dur, after_secs, dm=dm_mean, glm_test_method=glm_test_m)
                        epochs_labels_data[sub_key]['all_ss_mean'] = prepare_epochs_labels(conc_sub, stim_df, splits, trial_types, before_secs, max_stim_dur, after_secs, dm=dm_mean, glm_test_method=glm_test_m)

                        if not synthetic:
                            # Nuisance-only ("loso") design matrices and OD variants are only needed here.
                            dm_mean_loso = (
                                glm.design_matrix.drift_regressors(ts_long, drift_order=1)
                                & glm.design_matrix.average_short_channel_regressor(ss_reg)
                            )
                            dm_mean_od = (
                                glm.design_matrix.hrf_regressors(
                                    od_sub, stim_df, glm.Gamma(tau=0 * units.s, sigma=sigma_val * units.s, T=T_val * units.s)
                                )
                                & glm.design_matrix.drift_regressors(od_sub, drift_order=1)
                                & glm.design_matrix.average_short_channel_regressor(ss_reg_od)
                            )
                            dm_mean_od_loso = (
                                glm.design_matrix.drift_regressors(od_sub, drift_order=1)
                                & glm.design_matrix.average_short_channel_regressor(ss_reg_od)
                            )
                            epochs_labels_data[sub_key]['all_od_ss_mean'] = prepare_epochs_labels(od_sub, stim_df, splits, trial_types, before_secs, max_stim_dur, after_secs, dm=dm_mean_od, glm_test_method=glm_test_m)
                            epochs_labels_data[sub_key]['all_od_ss_mean_full'] = prepare_epochs_labels(od_sub, stim_df, splits, trial_types, before_secs, max_stim_dur, after_secs, dm=dm_mean_od_loso, glm_test_method=glm_test_m, exclude_test=False)
                            epochs_labels_data[sub_key]['all_ss_mean_full'] = prepare_epochs_labels(conc_sub, stim_df, splits, trial_types, before_secs, max_stim_dur, after_secs, dm=dm_mean_loso, glm_test_method=glm_test_m, exclude_test=False)

                    print("------------------------------------")
                    print("END KEYS")
                    print("------------------------------------")
                    print(epochs_labels_data.keys())
                    # Save results
                    if save:
                        intensity_folder = ""
                        if synthetic:
                            intensity_folder = "01" if int_key == "" else int_key.lstrip("_")
                        epo_label_path = config.epochs_labels_path(subject=subject, run=run_idx, int_scaling=intensity_folder, spatial_scaling=spatial_scaling)
                        full_output_path = os.path.join(base_path, epo_label_path)
                        print("Saving to: ", full_output_path)
                        os.makedirs(os.path.dirname(full_output_path), exist_ok=True)

                        with open(full_output_path, 'wb') as f:
                            pickle.dump(epochs_labels_data, f)

                        print(f"Saved: {full_output_path}")