In [2]:
import os
import glob
import numpy as np
import pandas as pd
import xarray as xr
import pickle as pkl

import mne
import cmlreaders as cml
import cluster_helper.cluster
import matplotlib.pyplot as plt

from constants import FR2_valid_subjects, subjects_powerfile_elsewhere

In [8]:
# Sanity check: flag every saved per-session power file whose channel
# dimension does not have the expected length. Subjects numbered above 330
# are expected to have 128 channels, all others 124.
problematic_sessions = [] # ['LTP357','LTP360','LTP361','LTP365','LTP366']
for s in FR2_valid_subjects:
    print(s, end=' ')
    nchan = 128 if int(s[-3:]) > 330 else 124
    shared_pattern = '/scratch/liyuxuan/ltpFR2/mtpower/%s/%s_%s_logmtpowerts_*.pkl'
    if s=='LTP360' or s=='LTP361':
        # For these two subjects the encoding files are read from the relative
        # 'scratch/mtpower/' directory and only the retrieval files from the
        # shared location.
        # NOTE(review): presumably the relative path holds recomputed files -- confirm.
        files = glob.glob('scratch/mtpower/%s/%s_enc_logmtpowerts_*.pkl'%(s,s)) + glob.glob(shared_pattern%(s,s,'ret'))
    else:
        files = glob.glob(shared_pattern%(s,s,'*'))
    for f in files:
        # Close each file right after loading instead of leaking the handle.
        with open(f,'rb') as fh:
            x = pkl.load(fh)
        if len(x.channels.values)!=nchan:
            print('uh oh', s, f, len(x.channels.values))
            problematic_sessions.append(f)

LTP093 LTP106 LTP115 LTP117 LTP123 LTP133 LTP138 LTP207 LTP210 LTP228 LTP229 LTP236 LTP246 LTP249 LTP250 LTP251 LTP258 LTP259 LTP265 LTP269 LTP273 LTP278 LTP279 LTP280 LTP283 LTP285 LTP287 LTP293 LTP296 LTP297 LTP299 LTP301 LTP302 LTP303 LTP304 LTP305 LTP306 LTP307 LTP310 LTP311 LTP312 LTP316 LTP317 LTP318 LTP321 LTP322 LTP323 LTP324 LTP325 LTP326 LTP327 LTP328 LTP329 LTP331 LTP334 LTP336 LTP339 LTP341 LTP342 LTP343 LTP344 LTP346 LTP347 LTP348 LTP349 LTP354 LTP355 LTP357 uh oh LTP357 /scratch/liyuxuan/ltpFR2/mtpower/LTP357/LTP357_ret_logmtpowerts_3.pkl 263
uh oh LTP357 /scratch/liyuxuan/ltpFR2/mtpower/LTP357/LTP357_ret_logmtpowerts_6.pkl 135
uh oh LTP357 /scratch/liyuxuan/ltpFR2/mtpower/LTP357/LTP357_enc_logmtpowerts_3.pkl 263
uh oh LTP357 /scratch/liyuxuan/ltpFR2/mtpower/LTP357/LTP357_enc_logmtpowerts_6.pkl 135
LTP360 uh oh LTP360 scratch/mtpower/LTP360/LTP360_enc_logmtpowerts_3.pkl 135
uh oh LTP360 scratch/mtpower/LTP360/LTP360_enc_logmtpowerts_7.pkl 263
uh oh LTP360 /scratch/liyuxua

In [11]:
def compute_power(subject):
    """Compute log10 multitaper power for one ltpFR2 session and pickle it.

    All imports are done inside the function so that it can be shipped to
    worker engines through ``view.map``.

    Parameters
    ----------
    subject : str
        Combined ``'<subject>-<session>'`` identifier, e.g. ``'LTP357-3'``.

    Returns
    -------
    None
        The (time x events x channels x frequency) power array is written to
        ``<save_dir><subject>/<subject>_<phase>_logmtpowerts_<session>.pkl``.

    Raises
    ------
    RuntimeError
        If the session does not have exactly one raw EEG recording.
    """

    # The argument packs the subject code and the session number together.
    parts = subject.split('-')
    session = int(parts[1])
    subject = parts[0]

    task_phase_indicator = 'ret' # 'enc' or 'ret'

    save_dir = 'scratch/mtpower/' # '/scratch/liyuxuan/ltpFR2/mtpower/'

    import os
    import glob
    import numpy as np
    import xarray as xr
    import pickle as pkl

    import mne
    import cmlreaders as cml

    from constants import frequencies as freqs_keep, all_biosemi_channels

    # timepoints of interest: centres of 500ms windows, 50ms step size
    if task_phase_indicator == 'enc':
        # centres 250ms ~ 1350ms, so the windows span the encoding period 0 ~ 1600ms
        timepoints = np.arange(250, 1350+1, 50)
    if task_phase_indicator == 'ret':
        # centres -750ms ~ 0ms relative to the recall event
        timepoints = np.arange(-750, 0+1, 50)
    moving_window_size = 500 # ms

    def process_raw_eeg(subject, session):
        """Load and filter the raw recording; return None unless exactly one file is found."""

        eegpath = '/protocols/ltp/subjects/%s/experiments/ltpFR2/sessions/%d/ephys/current_processed/' % (subject, session)
        # Subjects numbered above 330 are read as Biosemi (.bdf), the rest as EGI.
        eeg_system = 'bio' if int(subject[-3:]) > 330 else 'egi'

        if eeg_system=='egi':
            eegfile = glob.glob(eegpath+'*.2.raw') + glob.glob(eegpath+'*.1.raw') + glob.glob(eegpath+'*.mff')
            if len(eegfile)!=1: return
            eegfile = eegfile[0]
            raw = mne.io.read_raw_egi(eegfile, preload=True)
            raw.rename_channels({'E129': 'Cz'})
            raw.set_montage(mne.channels.read_montage('GSN-HydroCel-129'))
            raw.set_channel_types({'E8': 'eog', 'E25': 'eog', 'E126': 'eog', 'E127': 'eog', 'Cz': 'misc'})
        if eeg_system=='bio':
            eegfile = glob.glob(eegpath+'*.bdf')
            if len(eegfile)!=1: return
            eegfile = eegfile[0]
            raw = mne.io.read_raw_edf(eegfile,
                                      eog=['EXG1', 'EXG2', 'EXG3', 'EXG4'],
                                      misc=['EXG5', 'EXG6', 'EXG7', 'EXG8'],
                                      stim_channel='Status',
                                      montage='biosemi128',
                                      preload=True) # needs to be true for 0.1Hz high-pass filter to work

        # Collect bad-channel names from every *_bad_chan[0-2].txt file.
        badchanfiles = glob.glob(eegpath+'*_bad_chan[0-2].txt')
        if len(badchanfiles) > 0:
            bad = []
            for bcf in badchanfiles:
                with open(bcf, 'r') as f:
                    # Blank lines would otherwise end up as '' channel names.
                    bad.extend(name for name in (line.strip() for line in f) if name)
            raw.info['bads'] = bad

        # high-pass filter
        raw.filter(l_freq=0.1, h_freq=None) # fir
        # line noise filter: l_freq (62) > h_freq (58) requests a band-stop around 60Hz
        raw.filter(62.0, 58.0, method='iir', iir_params=dict(ftype='butter', order=4, output='sos'))

        return raw

    def compute_session_power(subject, session):
        """Return log10 power as a float32 DataArray (time, events, channels, frequency)."""

        reader = cml.CMLReader(subject=subject,
                               experiment='ltpFR2',
                               session=session)
        events = reader.load('events')
        if task_phase_indicator == 'enc':
            events = events[events['type']=='WORD']
        if task_phase_indicator == 'ret':
            events = events[events['type']=='REC_WORD']

        raw = process_raw_eeg(subject, session)
        if raw is None:
            # Fail with a clear message instead of passing None into mne.Epochs.
            raise RuntimeError('expected exactly one raw EEG file for %s session %d'
                               % (subject, session))

        # compute power at each timepoint of interest
        # -- ensuring matching returning frequencies
        power_allinterval = []
        for t in timepoints:

            # MNE event array: column 0 holds the sample offset, the other two stay 0.
            mne_events = np.zeros((len(events), 3), dtype=int)
            mne_events[:, 0] = np.asarray(events['eegoffset'])
            # The epoch is cut 0.2s longer than the window; only the first
            # 250 samples (after resampling) are used below.
            # NOTE(review): if MNE drops any epoch, the events coordinate below
            # no longer matches the data length -- confirm this cannot happen.
            epochs = mne.Epochs(raw, mne_events,
                                tmin=(t-moving_window_size/2)/1000.0,
                                tmax=(t+moving_window_size/2)/1000.0+0.2,
                                baseline=None, preload=True, on_missing='ignore')

            epochs._data = epochs._data * 1000000 # convert to microvolts

            # epochs.pick_types(eeg=True, exclude=[])
            # ^ doesn't work for some Biosemi sessions with montage set to weird settings when recording
            # replacing this line with the following to explicitly select the subset of channels we want before avg ref
            epochs.pick_channels(ch_names=all_biosemi_channels)

            # use custom avg reference so bad channels are excluded in computing avg but still referenced
            channel_avg = epochs.copy().pick_types(eeg=True, exclude='bads')._data.mean(1)
            epochs._data = epochs._data - np.repeat(np.expand_dims(channel_avg, 1),
                                                    len(epochs.info['ch_names']),
                                                    axis=1)
            epochs.resample(500.0) # resample to 500hz

            tminind = 0 # np.where(np.isclose(epochs.times, t/1000-moving_window_size/1000/2))[0][0]
            tmaxind = tminind + 250 # 0.5s in 500hz space --> 250 samples
            x = mne.EpochsArray(epochs.get_data()[:,:,tminind:tmaxind], epochs.info, verbose=False)

            x.info['bads'] = [] # keep bad channels in computing power
            power, fdone = mne.time_frequency.psd_multitaper(x,
                                                             fmin=2.0,
                                                             fmax=128.0,
                                                             verbose=False)
            power = xr.DataArray(power,
                                 dims=('events','channels','frequency'),
                                 coords={'events':events.to_records() if type(events)!=np.recarray else events,
                                         'channels':epochs.info['ch_names'],
                                         'frequency':fdone})
            # keep only the frequencies listed in constants.frequencies
            power = power.sel(frequency=freqs_keep)
            power_allinterval.append(power)

        # concat into times x events x channels x frequencies
        power = xr.concat(power_allinterval, dim='time')
        power.coords['time'] = timepoints
        del events, epochs

        # post-power-computation processing
        power = np.log10(power)
        power.values = power.values.astype(np.float32)

        return power

    # Several sessions of the same subject may run at once on different
    # engines, so directory creation must tolerate the directory already
    # existing (and create missing parents).
    subject_dir = '%s%s/' % (save_dir, subject)
    os.makedirs(subject_dir, exist_ok=True)

    outf = save_dir + '%s/%s_%s_logmtpowerts_%d.pkl' % (subject, subject, task_phase_indicator, session)
    power = compute_session_power(subject, session)
    with open(outf, 'wb') as fh:
        pkl.dump(power, fh)
    
#     for session in range(24):
        
#         outf = save_dir + '%s/%s_%s_logmtpowerts_%d.pkl' % (subject, subject, task_phase_indicator, session)
        
#         if os.path.exists(outf):
#             continue
        
#         try:
#             power = compute_session_power(subject, session)
            
#             # save this session
#             # pkl.dump(power, open(outf, 'wb'))
#             print(outf)
#             return power
        
#         except Exception as e:
#             print(e)
#             continue

In [None]:
# Re-run the power computation for the flagged sessions only, one job per session.
# (Alternative for a full run: subjects = FR2_valid_subjects, njobs = 30, cpj = 4.)
rerun_plan = (
    ('LTP357', (3, 6)),
    ('LTP360', (3, 7)),
    ('LTP361', (4, 9, 12)),
    ('LTP365', (14,)),
    ('LTP366', (0,)),
)
# compute_power expects '<subject>-<session>' strings.
problematic_sessions = ['%s-%d' % (subj, sess)
                        for subj, sessions in rerun_plan
                        for sess in sessions]
njobs = len(problematic_sessions)
cpj = 4

with cluster_helper.cluster.cluster_view(
        scheduler='sge',
        queue='RAM.q',
        num_jobs=njobs,
        cores_per_job=cpj,
) as view:
    view.map(compute_power, problematic_sessions)

9 Engines running


In [9]:
def concat_power(subject, task='ret'):
    """Z-score each session's log power and pickle the sessions concatenated.

    Reads every ``scratch/mtpower/<subject>/*_<task>_logmtpowerts*.pkl`` file,
    reduces the time dimension, z-scores across events within the session, and
    writes the result to ``scratch/mtpower/<subject>/<subject>_<task>_zpower.pkl``.
    A file that fails to process is reported and skipped.

    Parameters
    ----------
    subject : str
        Subject code, e.g. ``'LTP357'``.
    task : {'ret', 'enc'}, optional
        Which power files to process. Defaults to ``'ret'``.

    Returns
    -------
    None
        Nothing is written when no session could be processed.

    Raises
    ------
    ValueError
        If ``task`` is neither ``'enc'`` nor ``'ret'``.
    """
    if task not in ('enc', 'ret'):
        raise ValueError("task must be 'enc' or 'ret', got %r" % (task,))

    import glob
    import numpy as np
    import xarray as xr
    import pickle as pkl
    import cmlreaders as cml

    # Alternative location: '/scratch/liyuxuan/ltpFR2/mtpower/%s/'
    power_dir = 'scratch/mtpower/%s/' % subject

    def special_concat(arr_of_ts):
        '''
        Concatenate timeseries objects along their 'events' dim.
        Works around the events coordinate ending up with dtype='O' after
        concatenation by rebuilding it from the inputs' event arrays.
        Assumes every timeseries object has an 'events' dim.
        '''
        if len(arr_of_ts)==1:
            return arr_of_ts[0]
        x = xr.concat(arr_of_ts, dim='events')
        x.coords['events'] = np.concatenate([ts.events.values for ts in arr_of_ts])
        return x

    def filter_recalls(x):
        '''x is assumed to be a recall EEG timeseries from a single session.
           Drops recalls whose rectime is less than 1000 after the previous
           REC_WORD / REC_WORD_VV event in the same trial (or after 0 for the
           first one).
           NOTE(review): rectime is presumably in ms -- confirm.
        '''
        events = x.events.values
        keep_index = np.zeros(len(events), dtype=bool)

        # read events in again to get REC_WORD_VV vocalization events
        reader = cml.CMLReader(subject=np.unique(events['subject'])[0],
                               experiment='ltpFR2',
                               session=np.unique(events['session'])[0])
        original_events = reader.load('events')
        original_events = original_events[(original_events['type']=='REC_WORD') | (original_events['type']=='REC_WORD_VV')]

        for trial in np.unique(events['trial']):
            in_trial = events['trial']==trial
            trial_recs = events[in_trial]
            trial_recs_vvs = original_events[original_events['trial']==trial]

            # gap between each event and the one before it, kept for REC_WORD only
            rectimes = trial_recs_vvs['rectime']
            timebefore = np.diff(np.append([0], rectimes))
            timebefore = timebefore[trial_recs_vvs['type']=='REC_WORD']

            trial_valid_rec_flag = np.ones(len(trial_recs), dtype=bool)
            trial_valid_rec_flag[timebefore<1000] = 0

            keep_index[in_trial] = trial_valid_rec_flag

        return x.sel(events=keep_index)

    files = glob.glob(power_dir + '*_%s_logmtpowerts*.pkl' % task)

    power_all = []
    for f in files:

        try:
            with open(f,'rb') as fh:
                x = pkl.load(fh)

            if task == 'enc':
                # average the window centres 250 ~ 1350 (technically covers 0 ~ 1600)
                x = x.sel(time=(x.time>=250)&(x.time<=1350)).mean('time')
            else:
                # preserve events with no overlap w/ vocalization only
                x = filter_recalls(x)
                # select the window centred at -0.25s
                x = x.sel(time=-250)

            # z-score across events within the session
            means = x.mean('events')
            stds = x.std('events')
            x = (x-means)/stds

            power_all.append(x)

        except Exception as e:
            # Best effort: report which file failed and carry on with the rest.
            print(f, e)
            continue

    if not power_all:
        # xr.concat cannot handle an empty list; report and skip this subject.
        print('no usable %s power files for %s; nothing written' % (task, subject))
        return

    power_all = special_concat(power_all)
    outf = power_dir + '%s_%s_zpower.pkl' % (subject, task)
    with open(outf, 'wb') as fh:
        pkl.dump(power_all, fh)

In [None]:
# Build the z-scored power file for every subject in FR2_valid_subjects,
# printing each subject code before it is processed.
for subj in FR2_valid_subjects:
    print(subj)
    concat_power(subj)

LTP093
LTP106
LTP115
LTP117
LTP123
LTP133
LTP138
LTP207
LTP210
LTP228
LTP229
LTP236
LTP246
LTP249
LTP250
LTP251
LTP258
LTP259
LTP265
LTP269
LTP273
LTP278
LTP279
LTP280
LTP283
LTP285
LTP287
LTP293
LTP296
LTP297
LTP299
LTP301
LTP302
LTP303
LTP304
LTP305
LTP306
LTP307
LTP310
LTP311
LTP312
LTP316
LTP317
LTP318
LTP321
LTP322
LTP323
LTP324
LTP325
LTP326
LTP327
LTP328
LTP329
LTP331
LTP334
LTP336
LTP339
LTP341
LTP342
LTP343
LTP344
LTP346
LTP347
LTP348
LTP349
LTP354
LTP355
LTP357
LTP360
LTP361
LTP362
LTP364
LTP365
LTP366
LTP367
LTP371
LTP372
LTP373
LTP374
LTP376
LTP377
LTP385
LTP386
LTP387
LTP389
LTP390
LTP391
LTP393


In [10]:
# Rebuild the z-scored power file for these five subjects only.
rerun_subjects = ('LTP357', 'LTP360', 'LTP361', 'LTP365', 'LTP366')
for subj in rerun_subjects:
    print(subj)
    concat_power(subj)

LTP357
LTP360
LTP361
LTP365
LTP366
