In [19]:
"""This script estimates TRFs for several models and saves them"""
from pathlib import Path
import re

In [20]:
import eelbrain
import mne

In [21]:

# Data layout: ~/Data/cocoha/{predictors,eeg,stimuli,TRFs}
DATA_ROOT = Path.home() / 'Data' / 'cocoha'
PREDICTOR_DIR = DATA_ROOT / 'predictors'
EEG_DIR = DATA_ROOT / 'eeg'
# One sub-directory per subject whose name starts with S<number>. Sorted because
# Path.iterdir() order is filesystem-dependent, which made runs non-reproducible.
SUBJECTS = sorted(
    path.name for path in EEG_DIR.iterdir()
    if path.is_dir() and re.match(r'S\d+', path.name)
)

# Define a target directory for TRF estimates and make sure the directory is created
TRF_DIR = DATA_ROOT / 'TRFs'
TRF_DIR.mkdir(exist_ok=True)

STIMULUS_DIR = DATA_ROOT / 'stimuli'
# NOTE: despite the name, these are file stems (e.g. 'dss'), used as dict keys
# for the predictors; sorted for a deterministic order.
STIMULI_PATHS = sorted(stimulus.stem for stimulus in STIMULUS_DIR.glob("*.wav"))

In [103]:
def _load_gammatone_predictor(stem):
    """Load the `gammatone-8` predictor for one stimulus and prepare it for boosting.

    The spectrogram is binned to 1/64 s (64 Hz), padded so its time axis starts
    at 0 (named 'gammatone'), and band-pass filtered between 0.5 and 20 Hz.
    """
    spectrogram = eelbrain.load.unpickle(PREDICTOR_DIR / f'{stem}~gammatone-8.pickle')
    spectrogram = spectrogram.bin(1/64, dim='time', label='start')
    spectrogram = eelbrain.pad(spectrogram, tstart=0, name='gammatone')
    return eelbrain.filter_data(spectrogram, 0.5, 20)


# Stimulus stem -> processed gammatone predictor
gammatone = {stem: _load_gammatone_predictor(stem) for stem in STIMULI_PATHS}

In [14]:
def _load_predictor(stem, variant, name):
    """Load `{stem}~{variant}.pickle` from PREDICTOR_DIR and prepare it for boosting.

    The predictor is binned to 1/64 s (64 Hz), padded so its time axis starts
    at 0 and renamed to `name`, then band-pass filtered between 0.5 and 20 Hz.
    """
    predictor = eelbrain.load.unpickle(PREDICTOR_DIR / f'{stem}~{variant}.pickle')
    predictor = predictor.bin(1/64, dim='time', label='start')
    predictor = eelbrain.pad(predictor, tstart=0, name=name)
    return eelbrain.filter_data(predictor, 0.5, 20)


# Stimulus stem -> processed predictor. Only predictors used by the active
# models are built here; gammatone onsets/lin/pow and word-level predictors are
# not computed in this notebook (an empty dict would break TRF estimation).
envelope = {stem: _load_predictor(stem, 'gammatone-1', 'envelope') for stem in STIMULI_PATHS}
onset_envelope = {stem: _load_predictor(stem, 'gammatone-on-1', 'onset') for stem in STIMULI_PATHS}


In [None]:
# Models
# ------
# Pre-define models here to have easier access during estimation. Additional
# models can be added here and the script re-run to generate additional TRFs.
#
# Each model must be one of two dict forms (lists are NOT supported by the
# estimation loop):
#   * single predictor:    {stimulus_stem: NDVar}
#   * several predictors:  {predictor_name: {stimulus_stem: NDVar}}
# e.g. 'acoustic': {'gammatone': gammatone, 'gammatone_onsets': gammatone_onsets}
models = {
    # Acoustic envelope
    'envelope': envelope,
    # Gammatone spectrogram
    'gammatone': gammatone,
    # The acoustic edge detection model: envelope plus onset predictor
    'envelope+onset': {
        'envelope': envelope,
        'onset': onset_envelope,
    },
}

In [107]:
# Experiment table: one row per trial, with the attended talker and wav files.
# NOTE(review): this path is relative to the notebook's working directory,
# unlike the DATA_ROOT-based paths used elsewhere; confirm this is intended.
EXPINFO_PATH = "../expinfo.csv"
expinfo_table = eelbrain.load.tsv(EXPINFO_PATH, encoding='utf-8-sig').as_dataframe()
expinfo_table.head()

Unnamed: 0,attend_mf,attend_lr,acoustic_condition,n_speakers,wavfile_male,wavfile_female,trigger
0,1,2,1,1,'dss.wav','',254
1,2,1,2,2,'aske_story4_trial_1.wav','marianne_story3_trial_1.wav',133
2,2,2,1,2,'aske_story4_trial_2.wav','marianne_story3_trial_2.wav',135
3,1,1,1,1,'dss.wav','',252
4,2,2,2,2,'aske_story4_trial_3.wav','marianne_story3_trial_3.wav',150


In [109]:
# Estimate TRFs
# -------------
# The sequence of attended stimuli is the same for every subject and model, so
# derive it once from the experiment table. Single-speaker trials are skipped;
# attend_mf == 1 selects the male wav file, otherwise the female one.
attended_stimuli = []
for _, row in expinfo_table.iterrows():
    if row['n_speakers'] == 1:
        continue
    wavfile = row['wavfile_male'] if row['attend_mf'] == 1 else row['wavfile_female']
    # File names in the table are quoted (e.g. 'dss.wav'); strip quotes and extension
    attended_stimuli.append(wavfile.strip("'").removesuffix('.wav'))

# Loop through subjects to estimate TRFs
for subject in SUBJECTS:
    subject_trf_dir = TRF_DIR / subject
    subject_trf_dir.mkdir(exist_ok=True)
    # Generate all TRF paths so we can check whether any new TRFs need to be estimated
    trf_paths = {model: subject_trf_dir / f'{subject} {model}.pickle' for model in models}
    # Skip this subject if all files already exist
    if all(path.exists() for path in trf_paths.values()):
        continue
    # Load the EEG data
    raw = mne.io.read_raw(EEG_DIR / subject / f'{subject}-raw.fif', preload=True)
    # Band-pass filter the raw data between 0.5 and 20 Hz (same band as the predictors)
    raw.filter(0.5, 20, n_jobs=1)
    # Interpolate bad channels
    raw.interpolate_bads()
    # EEG data from cocoha already has been epoched and concatenated, so the
    # continuous recording is used directly as the response
    eeg = eelbrain.load.mne.raw_ndvar(raw)

    for model, predictor_dicts in models.items():
        trf_path = trf_paths[model]
        if trf_path.exists():
            continue
        if not predictor_dicts:
            raise ValueError(f"Model {model!r} has no predictors")

        print(f"Estimating: {subject} ~ {model}")

        # A model is either {stimulus: NDVar} (single predictor) or
        # {predictor_name: {stimulus: NDVar}} (several predictors); normalize
        # the single-predictor form into a one-entry dict of the second form.
        if isinstance(next(iter(predictor_dicts.values())), dict):
            predictor_types = predictor_dicts
        else:
            predictor_types = {'predictor': predictor_dicts}

        model_predictors = []  # one concatenated NDVar per predictor type
        for predictor_name, predictor_dict in predictor_types.items():
            trial_predictors = [predictor_dict[key] for key in attended_stimuli]
            predictor_long = eelbrain.concatenate(trial_predictors)

            print(f"  Predictor '{predictor_name}' trials: {len(trial_predictors)}")
            print(f"  Predictor '{predictor_name}' duration: {predictor_long.time.tstop:.2f}s")

            model_predictors.append(predictor_long)

        # -------------------
        # Final sanity check
        # -------------------
        # Fail before the (slow) boosting fit if any predictor does not have
        # exactly as many samples as the EEG response.
        print("EEG duration:", eeg.time.tstop)
        print("Predictor duration:", model_predictors[0].time.tstop)
        for predictor_long in model_predictors:
            if predictor_long.time.nsamples != eeg.time.nsamples:
                raise ValueError(
                    f"{subject} ~ {model}: predictor {predictor_long.name!r} has "
                    f"{predictor_long.time.nsamples} samples, EEG has {eeg.time.nsamples}"
                )

        # Fit TRF
        trf = eelbrain.boosting(
            eeg,
            model_predictors,
            -0.100,
            1.000,
            error='l1',
            basis=0.050,
            partitions=5,
            test=1,
            selective_stopping=True
        )

        eelbrain.save.pickle(trf, trf_path)


  raw.interpolate_bads()


Estimating: S4_data_preproc ~ envelope
  Predictor 'predictor' trials: 60
  Predictor 'predictor' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0
Estimating: S4_data_preproc ~ gammatone
  Predictor 'predictor' trials: 60
  Predictor 'predictor' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0
Estimating: S4_data_preproc ~ envelope+onset
  Predictor 'envelope' trials: 60
  Predictor 'envelope' duration: 3000.00s
  Predictor 'onset' trials: 60
  Predictor 'onset' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0


  raw.interpolate_bads()


Estimating: S3_data_preproc ~ envelope
  Predictor 'predictor' trials: 60
  Predictor 'predictor' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0
Estimating: S3_data_preproc ~ gammatone
  Predictor 'predictor' trials: 60
  Predictor 'predictor' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0
Estimating: S3_data_preproc ~ envelope+onset
  Predictor 'envelope' trials: 60
  Predictor 'envelope' duration: 3000.00s
  Predictor 'onset' trials: 60
  Predictor 'onset' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0


  raw.interpolate_bads()


Estimating: S15_data_preproc ~ envelope
  Predictor 'predictor' trials: 60
  Predictor 'predictor' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0
Estimating: S15_data_preproc ~ gammatone
  Predictor 'predictor' trials: 60
  Predictor 'predictor' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0
Estimating: S15_data_preproc ~ envelope+onset
  Predictor 'envelope' trials: 60
  Predictor 'envelope' duration: 3000.00s
  Predictor 'onset' trials: 60
  Predictor 'onset' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0


  raw.interpolate_bads()


Estimating: S12_data_preproc ~ envelope
  Predictor 'predictor' trials: 60
  Predictor 'predictor' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0
Estimating: S12_data_preproc ~ gammatone
  Predictor 'predictor' trials: 60
  Predictor 'predictor' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0
Estimating: S12_data_preproc ~ envelope+onset
  Predictor 'envelope' trials: 60
  Predictor 'envelope' duration: 3000.00s
  Predictor 'onset' trials: 60
  Predictor 'onset' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0


  raw.interpolate_bads()


Estimating: S13_data_preproc ~ envelope
  Predictor 'predictor' trials: 60
  Predictor 'predictor' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0
Estimating: S13_data_preproc ~ gammatone
  Predictor 'predictor' trials: 60
  Predictor 'predictor' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0
Estimating: S13_data_preproc ~ envelope+onset
  Predictor 'envelope' trials: 60
  Predictor 'envelope' duration: 3000.00s
  Predictor 'onset' trials: 60
  Predictor 'onset' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0


  raw.interpolate_bads()


Estimating: S14_data_preproc ~ envelope
  Predictor 'predictor' trials: 60
  Predictor 'predictor' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0
Estimating: S14_data_preproc ~ gammatone
  Predictor 'predictor' trials: 60
  Predictor 'predictor' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0
Estimating: S14_data_preproc ~ envelope+onset
  Predictor 'envelope' trials: 60
  Predictor 'envelope' duration: 3000.00s
  Predictor 'onset' trials: 60
  Predictor 'onset' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0


  raw.interpolate_bads()


Estimating: S2_data_preproc ~ envelope
  Predictor 'predictor' trials: 60
  Predictor 'predictor' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0
Estimating: S2_data_preproc ~ gammatone
  Predictor 'predictor' trials: 60
  Predictor 'predictor' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0
Estimating: S2_data_preproc ~ envelope+onset
  Predictor 'envelope' trials: 60
  Predictor 'envelope' duration: 3000.00s
  Predictor 'onset' trials: 60
  Predictor 'onset' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0


  raw.interpolate_bads()


Estimating: S5_data_preproc ~ envelope
  Predictor 'predictor' trials: 60
  Predictor 'predictor' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0
Estimating: S5_data_preproc ~ gammatone
  Predictor 'predictor' trials: 60
  Predictor 'predictor' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0
Estimating: S5_data_preproc ~ envelope+onset
  Predictor 'envelope' trials: 60
  Predictor 'envelope' duration: 3000.00s
  Predictor 'onset' trials: 60
  Predictor 'onset' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0


  raw.interpolate_bads()


Estimating: S8_data_preproc ~ envelope
  Predictor 'predictor' trials: 60
  Predictor 'predictor' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0
Estimating: S8_data_preproc ~ gammatone
  Predictor 'predictor' trials: 60
  Predictor 'predictor' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0
Estimating: S8_data_preproc ~ envelope+onset
  Predictor 'envelope' trials: 60
  Predictor 'envelope' duration: 3000.00s
  Predictor 'onset' trials: 60
  Predictor 'onset' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0


  raw.interpolate_bads()


Estimating: S18_data_preproc ~ envelope
  Predictor 'predictor' trials: 60
  Predictor 'predictor' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0
Estimating: S18_data_preproc ~ gammatone
  Predictor 'predictor' trials: 60
  Predictor 'predictor' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0
Estimating: S18_data_preproc ~ envelope+onset
  Predictor 'envelope' trials: 60
  Predictor 'envelope' duration: 3000.00s
  Predictor 'onset' trials: 60
  Predictor 'onset' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0


  raw.interpolate_bads()


Estimating: S9_data_preproc ~ envelope
  Predictor 'predictor' trials: 60
  Predictor 'predictor' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0
Estimating: S9_data_preproc ~ gammatone
  Predictor 'predictor' trials: 60
  Predictor 'predictor' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0
Estimating: S9_data_preproc ~ envelope+onset
  Predictor 'envelope' trials: 60
  Predictor 'envelope' duration: 3000.00s
  Predictor 'onset' trials: 60
  Predictor 'onset' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0


  raw.interpolate_bads()


Estimating: S16_data_preproc ~ envelope
  Predictor 'predictor' trials: 60
  Predictor 'predictor' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0
Estimating: S16_data_preproc ~ gammatone
  Predictor 'predictor' trials: 60
  Predictor 'predictor' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0
Estimating: S16_data_preproc ~ envelope+onset
  Predictor 'envelope' trials: 60
  Predictor 'envelope' duration: 3000.00s
  Predictor 'onset' trials: 60
  Predictor 'onset' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0


  raw.interpolate_bads()


Estimating: S11_data_preproc ~ envelope
  Predictor 'predictor' trials: 60
  Predictor 'predictor' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0
Estimating: S11_data_preproc ~ gammatone
  Predictor 'predictor' trials: 60
  Predictor 'predictor' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0
Estimating: S11_data_preproc ~ envelope+onset
  Predictor 'envelope' trials: 60
  Predictor 'envelope' duration: 3000.00s
  Predictor 'onset' trials: 60
  Predictor 'onset' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0


  raw.interpolate_bads()


Estimating: S7_data_preproc ~ envelope
  Predictor 'predictor' trials: 60
  Predictor 'predictor' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0
Estimating: S7_data_preproc ~ gammatone
  Predictor 'predictor' trials: 60
  Predictor 'predictor' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0
Estimating: S7_data_preproc ~ envelope+onset
  Predictor 'envelope' trials: 60
  Predictor 'envelope' duration: 3000.00s
  Predictor 'onset' trials: 60
  Predictor 'onset' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0


  raw.interpolate_bads()


Estimating: S1_data_preproc ~ envelope
  Predictor 'predictor' trials: 60
  Predictor 'predictor' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0
Estimating: S1_data_preproc ~ gammatone
  Predictor 'predictor' trials: 60
  Predictor 'predictor' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0
Estimating: S1_data_preproc ~ envelope+onset
  Predictor 'envelope' trials: 60
  Predictor 'envelope' duration: 3000.00s
  Predictor 'onset' trials: 60
  Predictor 'onset' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0


  raw.interpolate_bads()


Estimating: S6_data_preproc ~ envelope
  Predictor 'predictor' trials: 60
  Predictor 'predictor' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0
Estimating: S6_data_preproc ~ gammatone
  Predictor 'predictor' trials: 60
  Predictor 'predictor' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0
Estimating: S6_data_preproc ~ envelope+onset
  Predictor 'envelope' trials: 60
  Predictor 'envelope' duration: 3000.00s
  Predictor 'onset' trials: 60
  Predictor 'onset' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0


  raw.interpolate_bads()


Estimating: S10_data_preproc ~ envelope
  Predictor 'predictor' trials: 60
  Predictor 'predictor' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0
Estimating: S10_data_preproc ~ gammatone
  Predictor 'predictor' trials: 60
  Predictor 'predictor' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0
Estimating: S10_data_preproc ~ envelope+onset
  Predictor 'envelope' trials: 60
  Predictor 'envelope' duration: 3000.00s
  Predictor 'onset' trials: 60
  Predictor 'onset' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0


  raw.interpolate_bads()


Estimating: S17_data_preproc ~ envelope
  Predictor 'predictor' trials: 60
  Predictor 'predictor' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0
Estimating: S17_data_preproc ~ gammatone
  Predictor 'predictor' trials: 60
  Predictor 'predictor' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0
Estimating: S17_data_preproc ~ envelope+onset
  Predictor 'envelope' trials: 60
  Predictor 'envelope' duration: 3000.00s
  Predictor 'onset' trials: 60
  Predictor 'onset' duration: 3000.00s
EEG duration: 3000.0
Predictor duration: 3000.0
