In [None]:
import numpy as np
import cedalion
import cedalion.nirs
import cedalion.datasets
import cedalion.plots
import xarray as xr
import cedalion.sim.synthetic_hrf as synHRF_ced
from cedalion import units
import cedalion.models.glm as glm
from sklearn.model_selection import StratifiedKFold
import pickle
import os
import cedalion.sigproc.motion_correct as motion_correct
import cedalion.models.glm.design_matrix as glm_dm
import matplotlib.pyplot as plt
from cedalion.sigproc.quality import repair_amp
import sys
import os
sys.path.append(os.path.join(os.getcwd(), '..', 'experiments'))
import configs

xr.set_options(display_expand_data=False);

## Configure and load the resting-state dataset

In [None]:
# Synthetic HRF type to inject (alternative used previously: 'Stroop').
syn_hrf_type = 'Finger_Tapping'
# True: load one spatial HRF file per subject; False: one HRF shared by all subjects.
multiple = True
# Selects the `..._chan_sp_<spatial_scaling>_...` HRF files and the `sp_<spatial_scaling>` output folder.
spatial_scaling = 3

# NOTE(review): the dataset config is fixed to 'Syn_Finger_Tapping' regardless of
# syn_hrf_type -- confirm this is intended when using another HRF type.
_dataset_config_name = 'Syn_Finger_Tapping'
_all_configs = configs.load_dataset_configs([_dataset_config_name], load_sensitivity=False)
data_config = _all_configs[_dataset_config_name]

In [None]:
# Paths are built by string concatenation, so configs.data_path_prefix is assumed
# to end with a path separator.
data_path = configs.data_path_prefix
dataset = 'NN22_Resting_State/'
rs_data_path = data_path + dataset + 'NN22_RS/'
subjects = data_config.subjects
syn_hrf_data_path = data_path + dataset + 'Full_SynHRF_Data/' + syn_hrf_type + '/'

# One spatial HRF file is loaded per subject when `multiple` is True.
if multiple:
    num_multiple = len(subjects)

# To debug on a single subject, override e.g. `subjects = ['sub-02']` -- but
# subject_has_two_runs is indexed by *position* in `subjects`, so adjust it as well.
# Per this list, the first 5 subjects have 2 runs and the rest only 1.
# NOTE(review): an earlier comment said "first 8 subjects" -- confirm against the data.
subject_has_two_runs = [True, True, True, True, True, False, False, False, False, False, False]
assert len(subject_has_two_runs) >= len(subjects), (
    "subject_has_two_runs must have an entry for every subject in `subjects`"
)
runs_2 = ['run-01', 'run-02']
runs_1 = ['run-01']

In [None]:
# Read every resting-state run of each subject; recs[subject] is a list ordered by run.
# Only the first element returned by read_snirf is kept for each file.
recs = {}
for subj in subjects:
    subj_runs = runs_2 if subject_has_two_runs[subjects.index(subj)] else runs_1
    snirf_files = [
        rs_data_path + subj + '/nirs/' + subj + '_task-RS_' + run_label + '_nirs.snirf'
        for run_label in subj_runs
    ]
    recs[subj] = [cedalion.io.read_snirf(snirf_file)[0] for snirf_file in snirf_files]

In [None]:
# Load the spatial HRF map(s) from NetCDF: one file per subject when `multiple`,
# otherwise a single shared file. `s_hrf` is the first (or only) map.

def _load_hrf_dataarray(relative_path):
    """Return the DataArray stored in ``data_path + relative_path``, loaded into memory.

    The file is opened in a context manager and loaded eagerly so that its handle is
    released right away instead of staying open for lazy access. The values live under
    ``__xarray_dataarray_variable__`` (presumably the file was written from an unnamed
    DataArray via ``to_netcdf``).
    """
    with xr.open_dataset(data_path + relative_path) as hrf_ds:
        return hrf_ds.__xarray_dataarray_variable__.load()


if multiple:
    spatial_hrfs = []
    for i in range(num_multiple):
        filename = f'{dataset}NN22_syn_act/{syn_hrf_type}/multiple/syn_hrf_{syn_hrf_type}_chan_sp_{spatial_scaling}_int_1_num{i}.nc'
        print(filename)
        spatial_hrfs.append(_load_hrf_dataarray(filename))
    s_hrf = spatial_hrfs[0]
else:
    filename = f'{dataset}NN22_syn_act/{syn_hrf_type}/syn_hrf_{syn_hrf_type}_chan_sp_{spatial_scaling}_int_1.nc'
    # Alternative parcel-based HRF file:
    #filename = f'NN22_Resting_State/NN22_syn_act/{syn_hrf_type}/syn_hrf_Finger_Tapping_parcel_SomMotA_7_LH_6_RH_int_1_num.nc'
    spatial_hrf = _load_hrf_dataarray(filename)
    s_hrf = spatial_hrf

In [None]:
# Sanity check: maximum value of s_hrf (the first subject's map when `multiple` is True).
s_hrf.max()

In [None]:
# Select the channels into which the synthetic HRF is injected.
roi_c3_c4_path = os.path.join(data_path, 'NN22_Resting_State', "NN_22_C3_C4_close_channels")

# NOTE: pickle.load can execute arbitrary code -- only load this file from a trusted source.
with open(roi_c3_c4_path, 'rb') as f:
    channel_roi_c3_c4 = pickle.load(f)

# ROI channels per synthetic HRF type. Previously any other type left `roi_channels`
# undefined and failed later with a NameError; fail early with a clear message instead.
roi_channels_by_type = {'Finger_Tapping': channel_roi_c3_c4}
if syn_hrf_type not in roi_channels_by_type:
    raise ValueError(
        f"No ROI channels defined for syn_hrf_type={syn_hrf_type!r}; "
        f"known types: {sorted(roi_channels_by_type)}"
    )
roi_channels = roi_channels_by_type[syn_hrf_type]

print(f'Number of channels in roi: {len(roi_channels)}')

In [None]:
# Reference recording (first run of sub-02) inspected by the following cells.
# NOTE(review): raises KeyError if 'sub-02' is not in `subjects`.
rec = recs['sub-02'][0]
amp, geo3d = rec['amp'], rec.geo3d

In [None]:
# Presumably the differential pathlength factor: the same value (6) for each of the two wavelengths.
# NOTE(review): `dpf` is not referenced by any later cell in this notebook.
dpf_values = [6, 6]
dpf = xr.DataArray(dpf_values, dims="wavelength", coords={"wavelength": rec["amp"].wavelength})

In [None]:
import cedalion.validators as validators

In [None]:
# Scalings applied to the summed synthetic HRF before it is added to the OD data;
# each value produces its own `od_w_hrf_<scale>` entry on the saved recording.
int_scales = [
    0.2,
    0.4,
    0.6,
]

In [None]:
# Inspect the channel coordinate of the reference recording's amplitude data.
amp["channel"]

In [None]:
# Trial types present in the spatial HRF. Only the commented-out stimulus generation
# in the main loop uses this list.
trial_type_coord = s_hrf["trial_type"]
trial_types = list(trial_type_coord.values)

In [None]:
# Nominal stimulus duration in seconds. It is only referenced by the commented-out
# stimulus-generation code in the main loop; stim_dfs are loaded from disk instead.
stim_dur = 5

In [None]:
# Gamma basis function passed to build_synthetic_hrf_timeseries in the main loop.
basis_fct = glm.Gamma(
    tau=0 * units.s,
    sigma=3 * units.s,
    T=3 * units.s,
)

In [None]:
# Inject the synthetic HRF into every run of every subject and write the result to SNIRF.
# For each run:
#   1. restrict the amplitude data to the ROI channels;
#   2. convert it to optical density;
#   3. build the synthetic HRF time series from the pre-generated stimuli;
#   4. store one `od_w_hrf_<scale>` entry per value in `int_scales`, then save.
for subject_id, subject in enumerate(subjects):
    print(f"Subject: {subject}")

    # recs[subject] holds one entry per loaded run (1 or 2, per subject_has_two_runs),
    # so iterate over the runs that were actually loaded.
    runs_idx = range(len(recs[subject]))

    # Per-subject spatial HRF map when `multiple`, otherwise the single shared map.
    spatial_hrf = spatial_hrfs[subject_id] if multiple else s_hrf

    for run in runs_idx:

        print(f"Run: {run}")

        # NOTE: the assignments below mutate recs[subject][run] in place, so re-running
        # this cell re-processes already-processed recordings.
        rec = recs[subject][run]
        rec['amp'].time.attrs["units"] = "second"
        rec['amp'] = rec['amp'].pint.dequantify().pint.quantify("V")
        rec['amp'] = repair_amp(rec["amp"])
        rec["amp"] = rec["amp"].sel(channel=roi_channels)

        od = cedalion.nirs.int2od(rec["amp"])

        # Stimulus DataFrames were generated once with the code below (kept for
        # provenance). They are reloaded from disk so every execution uses identical stimuli.
        #stim_df = synHRF_ced.build_stim_df(
        #    max_time=od.time.values[-1] * units.seconds,
        #    trial_types=trial_types,
        #    min_interval=(stim_dur + 5) * units.seconds,
        #    max_interval=(stim_dur + 10) * units.seconds,
        #    min_stim_dur = (stim_dur - 2) * units.seconds,
        #    max_stim_dur = (stim_dur + 2) * units.seconds,
        #    min_stim_value = 0.5,
        #    max_stim_value = 1.5,
        #    order="alternating",
        #)

        # save stim df to file
        #output_path_df = syn_hrf_data_path + "stim_dfs/" + subject + '/run' + str(run) + 'df.pickle'
        #if not os.path.exists(os.path.dirname(output_path_df)):
        #    os.makedirs(os.path.dirname(output_path_df))
        #with open(output_path_df, 'wb') as f:
        #    pickle.dump(stim_df, f)
        #print(f"Saved stim_df rec to {output_path_df}")

        # Same location the (commented-out) generation code above writes to.
        path_df = syn_hrf_data_path + 'stim_dfs/' + subject + '/run' + str(run) + 'df.pickle'
        print("Load Stim DF from " + path_df)
        # NOTE: pickle.load can execute arbitrary code -- only load locally generated files.
        with open(path_df, 'rb') as f:
            stim_df = pickle.load(f)

        # Restrict the spatial HRF to the channels present in `od` (the ROI channels).
        spatial_hrf_masked = spatial_hrf.sel(channel=od.channel)

        syn_ts = synHRF_ced.build_synthetic_hrf_timeseries(od, stim_df, basis_fct, spatial_hrf_masked)
        # Collapse all trial types into a single synthetic signal.
        syn_ts_sum = syn_ts.sum(dim='trial_type')
        # Copy source/detector coordinates from `od` (presumably they are not carried
        # over by build_synthetic_hrf_timeseries -- confirm).
        syn_ts_sum['source'] = od.source
        syn_ts_sum['detector'] = od.detector

        rec.stim = stim_df
        print(f"Stimulus DataFrame:\n{rec.stim}")

        # One OD time series per scaling, stored e.g. as rec['od_w_hrf_0.2'].
        for int_scaling in int_scales:
            rec[f'od_w_hrf_{int_scaling}'] = od + syn_ts_sum * int_scaling

        # save the modified rec to file
        output_path = syn_hrf_data_path + "sp_" + str(spatial_scaling) + "/" + subject + '/nirs/' + subject + '_task-SynHRF_' + str(run) + '_nirs.snirf'
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        cedalion.io.write_snirf(output_path, rec)
        print(f"Saved modified rec to {output_path}")


In [None]:
# Display the root folder of the synthetic-HRF data (stim_dfs/ and sp_<scaling>/ live under it).
syn_hrf_data_path