COGS 4290 RSA

David Halpern 4/4/23

In [None]:
%matplotlib inline
import numpy as np
import pandas as pd
import scipy as sp
import xarray as xr
import seaborn as sns
import cmlreaders as cml
import cmldask.CMLDask as da
import os
import traceback
from ptsa.data.filters import ResampleFilter, ButterworthFilter, MorletWaveletFilter
from ptsa.data.timeseries import TimeSeries

In [None]:
# Subject codes included in this analysis.
subs = [
    'R1045E', 'R1102P', 'R1108J', 'R1141T',
    'R1144E', 'R1157C', 'R1192C', 'R1202M',
    'R1226D', 'R1236J', 'R1269E', 'R1277J',
    'R1291M', 'R1310J', 'R1328E', 'R1330D',
    'R1337E', 'R1351M', 'R1354E', 'R1361C',
    'R1375C', 'R1383J', 'R1389J', 'R1390M',
    'R1395M', 'R1401J', 'R1403N', 'R1412M',
    'R1465D', 'R1468J', 'R1474T', 'R1476J',
    'R1477J', 'R1482J', 'R1486J', 'R1490T',
    'R1497T', 'R1501J', 'R1515T', 'R1525J',
    'R1527J', 'R1530J', 'R1536J', 'R1541T',
]

In [None]:
# Load the CML data index and keep only the rows for the chosen experiment
# and the subjects listed in `subs`.
ix = cml.get_data_index()
exp = 'catFR1'
# `@exp` and `@subs` refer to the Python variables of this notebook; comparing
# a column to a list with `==` inside `query` acts as a membership test.
exp_ix = ix.query('experiment == @exp and subject == @subs')

# Preprocess data

We first load the data and preprocess it into power at different frequency bands, both during encoding (300 ms to 1300 ms after the word appears on screen) and right before retrieval (900 ms to 100 ms before vocalization). In order to make signals comparable across encoding and retrieval, we standardize both of them using data from the 10 s countdown periods that preceded each list. This means that all of the data is expressed relative to the countdown period (e.g., the 3 Hz band has 0.5 standard deviations more power than the average of the countdown period).

In [None]:
def compute_features(row,
                     settings_path='/home1/djhalp/ieeg_rsa/catFR1_test_encoding.pkl',
                     save_path='/scratch/djh/rsa_class',
                     overwrite=False):
    """
    Compute log-transformed spectral power features for one session and save them.

    Events of type ``settings.type`` are loaded for the session described by
    ``row``, the matching (buffered) EEG epochs are read, 58-62 Hz line noise
    is filtered out, and Morlet wavelet power is computed at ``settings.freqs``.
    Power is log10-transformed, resampled, stripped of its buffer and stacked
    as (channel, frequency) into a single ``features`` dimension. For every
    event type except ``COUNTDOWN_START`` the power is also averaged over time.
    These features can later be normalized along the event axis.

    Parameters
    ----------
    row :
        One row of the data index (e.g. from ``DataFrame.itertuples()``) with
        ``subject``, ``experiment``, ``session``, ``localization`` and
        ``montage`` attributes.
    settings_path :
        Path of the saved settings object. The function reads its ``type``,
        ``rel_start``, ``rel_stop``, ``buffer_time``, ``freqs``, ``width`` and
        ``resameplerate`` (sic) attributes.
    save_path :
        Directory where ``<subject>_<session>_<type>_feats.h5`` is written.
        For ``COUNTDOWN_START`` it must also contain
        ``<experiment>_countdown_evs.csv``.
    overwrite :
        Recompute and overwrite the output file even if it already exists.

    Returns
    -------
    The stacked power array when features were computed, or ``None`` when the
    output file already existed and ``overwrite`` is False.
    """
    settings = da.Settings.Load(settings_path)
    save_fp = save_path + '/' + row.subject + '_' + str(row.session) + '_' + settings.type + '_feats.h5'
    print(save_fp)
    if os.path.exists(save_fp) and not overwrite:
        # Features for this session are already on disk; nothing to do.
        return None

    print("Reading subject:", row.subject, "session:", row.session)
    # initialize data reader, load word events and buffered eeg epochs
    r = cml.CMLReader(subject=row.subject,
                      experiment=row.experiment, session=row.session,
                      localization=row.localization, montage=row.montage)

    if settings.type == 'COUNTDOWN_START':
        # Countdown events come from a pre-built CSV rather than the task
        # events. The file name is derived from the row's own experiment so
        # the function does not depend on a notebook-level global.
        countdown_events = pd.read_csv(save_path + '/' + row.experiment + '_countdown_evs.csv')
        evs = countdown_events.query('subject == @row.subject and session == @row.session')
    else:
        evs = r.load('task_events')
    evs = evs.query('type == @settings.type and eegoffset != -1')
    # Copy the filtered frame so the column assignments below operate on an
    # independent DataFrame instead of a slice of the loaded events.
    evs = evs[evs['list'] > 0].copy()
    evs['category'] = evs['category'].str.lower()
    print(evs)
    scheme = r.load("pairs")
    if settings.type == "REC_WORD":
        # get inter-retrieval times since previous recall; the first recall of
        # each list has no predecessor, so its own rectime is used instead
        evs['pirt'] = evs.groupby(['session', 'list'])['rectime'].diff().fillna(
            evs['rectime'])
        evs['repeat'] = evs.duplicated(subset=['session', 'list', 'item_name'])
        evs['outpos'] = evs.groupby(['subject', 'session', 'list']).cumcount()
        # only include recalls at least 1500 ms away and no repeats
        evs = evs.query('pirt > 1500 and repeat == 0')
        eeg = r.load_eeg(evs,
                         rel_start=settings.rel_start,
                         rel_stop=settings.rel_stop,
                         scheme=scheme).to_ptsa()
        # Pad the epoch with a mirrored copy of itself instead of loading
        # extra samples around it.
        eeg['time'] = eeg['time'] / 1000  # PTSA time scale is in seconds instead of ms
        eeg = eeg.add_mirror_buffer(settings.buffer_time / 1000)
        eeg['time'] = eeg['time'] * 1000
    else:
        # Load real data on both sides of the epoch as the buffer.
        eeg = r.load_eeg(events=evs,
                         rel_start=-settings.buffer_time + settings.rel_start,
                         rel_stop=settings.buffer_time + settings.rel_stop,
                         scheme=scheme).to_ptsa()

    # centering signal within event
    # reduce edge effects / ringing in later processing steps:
    eeg = eeg.astype(float) - eeg.mean('time')
    # filter out line noise at 60 Hz
    eeg = ButterworthFilter(filt_type='stop',
                            freq_range=[58, 62],
                            order=4).filter(timeseries=eeg)
    pows = MorletWaveletFilter(freqs=settings.freqs,
                               width=settings.width,
                               output='power',
                               cpus=4).filter(timeseries=eeg)
    # The raw voltage is no longer needed; release it before the next steps.
    del eeg
    # resample to resamplerate
    print('resample rate', settings.resameplerate)
    # NOTE(review): xr.ufuncs is deprecated in newer xarray releases in favour
    # of plain numpy ufuncs; before switching to np.log10, confirm that it
    # still returns a ptsa TimeSeries (needed for remove_buffer below).
    pows = xr.ufuncs.log10(pows)
    pows = ResampleFilter(resamplerate=settings.resameplerate).filter(
        timeseries=pows)
    pows = pows.remove_buffer(settings.buffer_time / 1000)
    if settings.type != "COUNTDOWN_START":
        pows = pows.mean('time')  # average over time
    # reshape as events x features
    pows = pows.stack(features=("channel", "frequency"))
    if 'stim_params' in pows.indexes['event'].names:
        # Drop the stim_params level from the event index before saving.
        pows = pows.assign_coords(
            {"event": pows.indexes['event'].droplevel('stim_params')}
        )
    # Keep the preprocessing settings alongside the data.
    pows = pows.assign_attrs(settings.__dict__)
    pows.to_hdf(save_fp)
    return pows

In [None]:
# Preprocessing settings for the COUNTDOWN_START events that serve as the
# normalization baseline: 0-10000 ms after each countdown event.
settings = da.Settings()
# freqs = np.unique(np.round(np.logspace(np.log10(1), np.log10(300), 17)))
# 8 logarithmically spaced wavelet frequencies between 3 and 180
settings.freqs = np.logspace(np.log10(3),np.log10(180), 8)
settings.width = 6
settings.rel_start = 0
settings.rel_stop = 10000
# NOTE: the attribute name is spelled `resameplerate` (sic); compute_features
# reads it under exactly this name.
settings.resameplerate = 1
settings.experiment = 'catFR1'
settings.type = 'COUNTDOWN_START'
# settings.buffer_time = (settings.width / 2) * (1000 / min(settings.freqs))
# Fixed buffer; for width 6 and a 3 Hz minimum the formula above also gives 1000.
settings.buffer_time = 1000
# Saved relative to the current working directory.
settings.Save("catFR1_countdown_preprocess.pkl")

In [None]:
# Preprocessing settings for encoding (WORD) events: 300-1300 ms after the
# word appears on screen.
settings = da.Settings()
settings.width = 6
settings.rel_start = 300
settings.rel_stop = 1300
settings.experiment = 'catFR1'
# 8 logarithmically spaced wavelet frequencies between 3 and 180
settings.freqs = np.logspace(np.log10(3), np.log10(180), 8)
settings.type = 'WORD'
# Buffer of half the wavelet width, in cycles of the lowest frequency,
# converted to milliseconds.
settings.buffer_time = (settings.width / 2) * (1000 / min(settings.freqs))
# NOTE: the attribute name is spelled `resameplerate` (sic); compute_features
# reads it under exactly this name.
settings.resameplerate = 500
# Saved relative to the current working directory.
settings.Save("catFR1_encoding_preprocess.pkl")

In [None]:
# Preprocessing settings for retrieval (REC_WORD) events: 900 to 100 ms
# before each vocalization.
settings = da.Settings()
# freqs = np.unique(np.round(np.logspace(np.log10(1), np.log10(300), 17)))
# 8 logarithmically spaced wavelet frequencies between 3 and 180
settings.freqs = np.logspace(np.log10(3),np.log10(180), 8)
settings.width = 6
settings.experiment = 'catFR1'
# NOTE: the attribute name is spelled `resameplerate` (sic); compute_features
# reads it under exactly this name.
settings.resameplerate = 500
settings.type = 'REC_WORD'
settings.rel_start = -900
settings.rel_stop = -100
# For REC_WORD, compute_features uses this as the length of a mirrored buffer.
# NOTE(review): 760 is presumably chosen to stay below the 800 ms epoch
# length -- confirm.
settings.buffer_time = 760
# Saved relative to the current working directory.
settings.Save("catFR1_retrieval_preprocess.pkl")

In [None]:
# Start the dask client used by every client.map call below.
# NOTE(review): the positional arguments are presumably the job name and the
# memory requested per job -- confirm against the cmldask documentation.
client = da.new_dask_client("class_preprocessing",
                            "55GB",
                            max_n_jobs=10,
                            log_directory='/scratch/djh/rsa_class/log_directory')

In [None]:
# Compute the countdown (baseline) features, one task per row of exp_ix.
# overwrite=True forces recomputation even when output files already exist.
futures = client.map(
    compute_features, 
    list(exp_ix.itertuples()), 
    overwrite=True,
    settings_path='/home1/djhalp/ieeg_rsa/catFR1_countdown_preprocess.pkl',
    save_path='/scratch/djh/rsa_class/')

In [None]:
# Number of futures currently reported with status 'pending'.
len(da.filter_futures(futures, status='pending'))

In [None]:
# wait(futures)
# Collect the exceptions raised by the submitted tasks and display them.
errors = da.get_exceptions(futures, range(len(futures)))
errors

We got a couple of errors, but it's not a huge deal; we'll just ignore these sessions for now. If we want to investigate an error, though, we can look at its traceback.

In [None]:
# Print the traceback of the first failed task. Guard against an empty error
# table so re-running this cell after a clean run does not raise IndexError.
if len(errors) > 0:
    traceback.print_tb(errors.traceback_obj.iloc[0])
else:
    print("No failed tasks to inspect.")

Now we'll process the encoding data

In [None]:
# Compute the encoding (WORD) features, one task per row of exp_ix.
# overwrite=False skips sessions whose output file already exists.
futures = client.map(
    compute_features, 
    list(exp_ix.itertuples()), 
    overwrite=False,
    settings_path='/home1/djhalp/ieeg_rsa/catFR1_encoding_preprocess.pkl',
    save_path='/scratch/djh/rsa_class/')

In [None]:
# Number of futures currently reported with status 'pending'.
len(da.filter_futures(futures, status='pending'))

In [None]:
# wait(futures)
# Collect the exceptions raised by the submitted tasks and display them.
errors = da.get_exceptions(futures, range(len(futures)))
errors

In [None]:
# Compute the retrieval (REC_WORD) features, one task per row of exp_ix.
# overwrite=False skips sessions whose output file already exists.
futures = client.map(
    compute_features, 
    list(exp_ix.itertuples()), 
    overwrite=False,
    settings_path='/home1/djhalp/ieeg_rsa/catFR1_retrieval_preprocess.pkl',
    save_path='/scratch/djh/rsa_class/')

In [None]:
# Number of futures currently reported with status 'pending'.
len(da.filter_futures(futures, status='pending'))

In [None]:
# Collect the exceptions raised by the submitted tasks and display only the
# 'exception' column.
errors = da.get_exceptions(futures, range(len(futures)))
errors['exception']

# Compute RSA

Now we have to compute the RSA. This basically involves loading the outputs of the `compute_features` function above and using the xarray `corr` function, which computes the correlation matrix between two data arrays and holds on to all the relevant information for us. There are two tricky things going on here. One is that we need to normalize the encoding and retrieval features by the countdown features; we have a function called `normalize_features` to do that. The other is that we need to change the names of the event information so that they don't match. If they do, the `corr` function will assume they refer to the same information and will not compute a correlation between them. In order to distinguish, for example, the `item_name` at retrieval from the one at encoding, we append the event `type` to the name of each column. If we are computing the correlation between two sets of events of the same `type` (e.g. each encoding event to every other encoding event), we add a 2 on the end to distinguish them.

In [None]:
# Settings for the encoding-to-encoding RSA: WORD features are correlated
# with WORD features.
settings = da.Settings()
settings.experiment = "catFR1"
settings.encoding_type = "WORD"
settings.comparison_type = "WORD"
# Passed to normalize_features as its countdown_normalize flag (truthy).
settings.countdown_normalize = 1
# Saved relative to the current working directory.
settings.Save("catFR1_encoding_rsa.pkl")

In [None]:
# Reuse the settings object from the previous cell, changing only the
# comparison type, to define the encoding-to-retrieval RSA.
settings.comparison_type = "REC_WORD"
settings.Save("catFR1_retrieval_rsa.pkl")

In [None]:
def normalize_features(pows, save_path, countdown_normalize=True):
    """
    Standardize the features in ``pows``.

    With ``countdown_normalize`` truthy, the COUNTDOWN_START features of the
    same subject and session (taken from the first event in ``pows``) are
    loaded from ``save_path``, and ``pows`` is expressed as the number of
    countdown standard deviations away from the countdown mean, both computed
    over the 'event' and 'time' dimensions. Otherwise ``pows`` is z-scored
    along its own 'event' dimension (ddof=1).

    Returns the normalized array.
    """
    first_subject = pows.event.subject.values[0]
    first_session = pows.event.session.values[0]

    if not countdown_normalize:
        return pows.reduce(func=sp.stats.zscore, dim='event', keep_attrs=True, ddof=1)

    baseline_fp = save_path + '/' + first_subject + '_' + str(first_session) + '_COUNTDOWN_START_feats.h5'
    baseline = TimeSeries.from_hdf(baseline_fp)
    # Give the baseline the same samplerate coordinate as pows before
    # combining the two arrays.
    baseline['samplerate'] = pows['samplerate']
    baseline_mean = baseline.mean(['event', 'time'])
    baseline_std = baseline.std(['event', 'time'])
    return (pows - baseline_mean) / baseline_std

def set_event_names(pows, period_type, col='event', copy=False):
    """
    Suffix the ``col`` dimension and all of its index levels with ``period_type``.

    Every level name of the ``col`` index becomes ``<name>_<period_type>`` and
    the dimension itself is renamed to ``<col>_<period_type>``. With
    ``copy=True`` the work is done on a copy of ``pows``; otherwise the index
    of the array passed in is reassigned in place before renaming.

    Returns the renamed array.
    """
    target = pows.copy() if copy else pows
    suffix = '_' + period_type
    renamed_index = target[col].to_index()
    renamed_index = renamed_index.rename(
        [level + suffix for level in renamed_index.names])
    target[col] = renamed_index
    return target.rename({col: col + suffix})

def compute_rsa(row, overwrite=False,
                settings_path='/home1/djhalp/ieeg_rsa/catFR1_encoding_rsa.pkl',
                save_path='/scratch/djh/rsa_class/'):
    """
    Compute the RSA (event-by-event feature correlation) between two periods.

    Loads the saved features of type ``settings.encoding_type`` and
    ``settings.comparison_type`` for the session described by ``row``,
    normalizes both with ``normalize_features``, correlates every encoding
    event with every comparison event across the ``features`` dimension, and
    writes the result as a long-format table with ``corr`` and the
    Fisher-transformed ``corr_z`` columns.

    Parameters
    ----------
    row :
        One row of the data index with ``subject`` and ``session`` attributes.
    overwrite :
        Recompute and overwrite the output file even if it already exists.
    settings_path :
        Path of the saved settings object providing ``encoding_type``,
        ``comparison_type`` and ``countdown_normalize``.
    save_path :
        Directory holding the ``*_feats.h5`` inputs; the output is written here.

    Returns
    -------
    None. The result is only written to disk. Note that the output file name
    ends in ``_rsa.h5`` although its content is CSV.
    """
    settings = da.Settings.Load(settings_path)
    print("Processing subject:", row.subject, "session:", row.session)
    session_prefix = save_path + '/' + row.subject + '_' + str(row.session) + '_'
    save_fp = session_prefix + settings.encoding_type + '_' + settings.comparison_type + '_rsa.h5'
    if os.path.exists(save_fp) and not overwrite:
        # RSA for this session is already on disk; nothing to do.
        return

    encoding_fp = session_prefix + settings.encoding_type + '_feats.h5'
    encoding_pows = TimeSeries.from_hdf(encoding_fp)
    encoding_pows = normalize_features(encoding_pows, save_path,
                                       countdown_normalize=settings.countdown_normalize)

    # need to rename event index to keep track of things for correlation matrix
    if settings.comparison_type == settings.encoding_type:
        # Same period on both sides: compare the encoding features with a
        # copy of themselves whose event names carry an extra '2'. The copy
        # must be taken before encoding_pows itself is renamed below.
        comparison_pows = set_event_names(encoding_pows, settings.encoding_type + '2',
                                          col='event', copy=True)
    else:
        comparison_fp = session_prefix + settings.comparison_type + '_feats.h5'
        comparison_pows = TimeSeries.from_hdf(comparison_fp)
        comparison_pows = normalize_features(comparison_pows, save_path,
                                             countdown_normalize=settings.countdown_normalize)
        comparison_pows = set_event_names(comparison_pows, settings.comparison_type, col='event')
    encoding_pows = set_event_names(encoding_pows, settings.encoding_type, col='event')

    corr_arr = xr.corr(encoding_pows, comparison_pows, dim='features')
    print('corr_arr', corr_arr.indexes)
    corr_df = corr_arr.to_dataframe('corr').reset_index()
    # Fisher z-transform; correlations of exactly +/-1 (e.g. an event
    # compared with itself) map to +/-inf.
    corr_df['corr_z'] = np.arctanh(corr_df['corr'])
    corr_df.to_csv(save_fp, index=False)

In [None]:
# Compute the encoding-to-encoding RSA, one task per row of exp_ix.
# overwrite=False skips sessions whose output file already exists.
futures = client.map(
    compute_rsa, 
    list(exp_ix.itertuples()), 
    overwrite=False,
    settings_path='/home1/djhalp/ieeg_rsa/catFR1_encoding_rsa.pkl',
    save_path='/scratch/djh/rsa_class/')

In [None]:
# Number of futures currently reported with status 'pending'.
len(da.filter_futures(futures, status='pending'))

In [None]:
# Collect the exceptions raised by the submitted tasks and display them.
errors = da.get_exceptions(futures, range(len(futures)))
errors

In [None]:
# Compute the encoding-to-retrieval RSA, one task per row of exp_ix.
# overwrite=False skips sessions whose output file already exists.
futures = client.map(
    compute_rsa, 
    list(exp_ix.itertuples()), 
    overwrite=False,
    settings_path='/home1/djhalp/ieeg_rsa/catFR1_retrieval_rsa.pkl',
    save_path='/scratch/djh/rsa_class/')

In [None]:
# Number of futures currently reported with status 'pending'.
len(da.filter_futures(futures, status='pending'))

In [None]:
# Collect the exceptions raised by the submitted tasks and display them.
errors = da.get_exceptions(futures, range(len(futures)))
errors

Finally, we just need to aggregate all the RSA matrices together into one giant dataframe for future analysis. For the encoding-to-encoding RSA, we'll also cut things down a bit by enforcing that the `_WORD` event comes before the `_WORD2` event. This is fine because the full correlation matrix is symmetric, so we're just keeping one triangle of it (and dropping the diagonal, where each event is compared with itself).

In [None]:
# Aggregate the per-session encoding-to-encoding RSA tables into one dataframe.
save_path = '/scratch/djh/rsa_class/'
encoding_type = "WORD"
comparison_type = "WORD"
experiment = 'catFR1'
ix = cml.get_data_index()
sub_ix = ix.query('experiment == @experiment and subject == @subs')
rsa_dfs = []

for _, row in sub_ix.iterrows():
    # Despite the .h5 suffix these files are written with DataFrame.to_csv.
    rsa_fp = save_path + str(row['subject']) + '_' + str(row['session']) + '_' + encoding_type + '_' + comparison_type + '_rsa.h5'
    try:
        rsa_df = pd.read_csv(rsa_fp)
        # The matrix is symmetric, so keep only pairs where the WORD2 event
        # comes strictly after the WORD event (later list, or same list and
        # later serial position).
        rsa_df = rsa_df.query('((list_WORD2 > list_WORD) or ' +
                              '((list_WORD2 == list_WORD) and ' +
                              '(serialpos_WORD2 > serialpos_WORD)))'
                             )
    except Exception as err:
        # Sessions that failed upstream have no (usable) RSA file. Skip them,
        # but say so instead of silently swallowing every error (a bare
        # `except` would also catch KeyboardInterrupt).
        print('Skipping', rsa_fp, '-', repr(err))
        continue
    rsa_dfs.append(rsa_df)

raw_WORD_rsa_df = pd.concat(rsa_dfs)
raw_WORD_rsa_df.to_csv(save_path + 'raw_WORD_rsa_df.csv', index=False)

In [None]:
# Mean correlation over all retained encoding-encoding event pairs.
raw_WORD_rsa_df['corr'].mean()

In [None]:
# Aggregate the per-session encoding-to-retrieval RSA tables into one dataframe.
save_path = '/scratch/djh/rsa_class/'
encoding_type = "WORD"
comparison_type = "REC_WORD"
experiment = 'catFR1'
ix = cml.get_data_index()
sub_ix = ix.query('experiment == @experiment and subject == @subs')
rsa_dfs = []

for _, row in sub_ix.iterrows():
    # Despite the .h5 suffix these files are written with DataFrame.to_csv.
    rsa_fp = save_path + str(row['subject']) + '_' + str(row['session']) + '_' + encoding_type + '_' + comparison_type + '_rsa.h5'
    try:
        rsa_df = pd.read_csv(rsa_fp)
    except Exception as err:
        # Sessions that failed upstream have no (usable) RSA file. Skip them,
        # but say so instead of silently swallowing every error (a bare
        # `except` would also catch KeyboardInterrupt).
        print('Skipping', rsa_fp, '-', repr(err))
        continue
    rsa_dfs.append(rsa_df)

raw_REC_WORD_rsa_df = pd.concat(rsa_dfs)
raw_REC_WORD_rsa_df.to_csv(save_path + 'raw_REC_WORD_rsa_df.csv', index=False)

In [None]:
# Display the correlation column of the encoding-retrieval table.
raw_REC_WORD_rsa_df['corr']

In [None]:
# Mean correlation over all encoding-retrieval event pairs.
raw_REC_WORD_rsa_df['corr'].mean()