In [1]:
import warnings

import pandas as pd
import numpy as np
import scipy.io as sio
from scipy.stats.mstats import zscore

import scipy.signal # to use signal.hilbert
import scipy.fftpack # to use fftpack.next_fast_len

from matplotlib import pyplot as plt
from tqdm.notebook import tqdm

from ptsa.data.filters import ButterworthFilter
from ptsa.data.filters import MorletWaveletFilter
from ptsa.data.filters import ResampleFilter
from ptsa.data.timeseries import TimeSeries

from cmlreaders import CMLReader, get_data_index

In [2]:
# Show every column when a DataFrame is displayed (no column truncation).
pd.options.display.max_columns = None

In [3]:
# Experiment whose subjects are pulled from the data index.
EXPERIMENT = 'TH1'
# Event type kept when building ERPs (default of get_subject_erps).
EVENTS_TYPE = 'CHEST'

# Epoch window relative to each event, in milliseconds
# (defaults for rel_start_ms / rel_stop_ms below).
REL_START = -2000
REL_STOP = 3500

# Extra padding (ms) loaded on both sides of the epoch and removed again
# after the Hilbert transform (default for buf_ms below).
BUFFER = 2000

# Default [low, high] pass band in Hz.
FREQ_BAND = [1, 3]

In [4]:
def subject_id(subject, montage):
    """Build the subject string used in file paths.

    Returns ``subject`` unchanged when ``montage`` equals 0, otherwise
    ``"<subject>_<montage>"`` with the montage rendered as an integer.
    """
    if montage == 0:
        return subject
    montage_number = int(montage)
    return f"{subject}_{montage_number}"

In [5]:
def get_all_events(**subject_dict):
    """Load and concatenate the events of every session of one subject.

    Parameters
    ----------
    **subject_dict
        Keyword arguments forwarded to ``CMLReader``. Must contain at least
        ``'experiment'``, ``'subject'`` and ``'montage'``, which are also used
        to select the subject's rows in the ``'r1'`` data index.

    Returns
    -------
    pandas.DataFrame
        All events that have a non-empty ``eegfile``. When there are more
        distinct eegfiles than sessions, the first event after an eegfile
        change inside the same session is dropped.

    Raises
    ------
    FileNotFoundError
        If events can be loaded neither by ``session`` nor by
        ``original_session`` number.
    ValueError
        From ``pd.concat`` when the data index lists no session at all.
    """
    df = get_data_index('r1')
    df = df[(df['experiment'] == subject_dict['experiment'])
            & (df['subject'] == subject_dict['subject'])
            & (df['montage'] == subject_dict['montage'])]
    original_sessions = df['original_session'].values
    sessions = df['session'].values

    def _load_sessions(session_ids):
        # Built as a fresh list so a failure half-way leaves nothing behind.
        return [CMLReader(**subject_dict, session=session).load('events')
                for session in session_ids]

    try:
        all_events = _load_sessions(sessions)
    except FileNotFoundError:
        # Fall back to the original session numbering. Starting from an empty
        # list avoids duplicating sessions that loaded before the failure.
        all_events = _load_sessions(original_sessions)

    all_events = pd.concat(all_events)
    all_events.index = range(len(all_events))

    all_events = all_events[all_events['eegfile'] != '']

    if len(all_events['eegfile'].unique()) > len(all_events['session'].unique()):
        # Compare each row with the row positionally before it. The first row
        # has no predecessor (NaN after the shift), so it is never flagged;
        # previously it was compared against the *last* row by wrap-around.
        previous = all_events[['eegfile', 'session']].shift(1)
        file_changed = all_events['eegfile'] != previous['eegfile']
        same_session = all_events['session'] == previous['session']
        all_events = all_events[~(file_changed & same_session)]

    return all_events

In [6]:
def load_matlab_contacts(subject, montage, load_type='contacts'):
    """ Load the MATLAB ``talStruct`` localization table for a subject.

        This is more consistent than the cmlreaders version, which sometimes
        breaks, and it is more obvious what is going on.

        Parameters
        ----------
        subject: str
            Subject code.
        montage: int
            Montage number; 0 means no suffix in the file path.
        load_type: str
            ``'contacts'`` reads the ``monopol`` file, ``'pairs'`` reads the
            ``bipol`` file.

        Returns
        -------
        pandas.DataFrame
            One row per entry of ``talStruct``.

        Raises
        ------
        ValueError
            If ``load_type`` is neither ``'contacts'`` nor ``'pairs'``.
    """
    # Validate before touching the filesystem, and keep the caller's value
    # for the error message (it used to reference an undefined name).
    file_kinds = {'contacts': 'monopol', 'pairs': 'bipol'}
    if load_type not in file_kinds:
        raise ValueError(
            f"load type must be 'pairs' or 'contacts': "
            f"{load_type!r} is invalid load_type.")
    file_kind = file_kinds[load_type]

    subj_str = subject_id(subject, montage)
    path = f'/data/eeg/{subj_str}/tal/{subj_str}_talLocs_database_{file_kind}.mat'
    contacts = pd.DataFrame(sio.loadmat(path, squeeze_me=True)['talStruct'])

    if subject == 'R1059J' and montage == 1:
        # this contact is a duplicate of contact 114
        contacts = contacts[contacts['channel'] != 113]
        contacts.index = range(len(contacts))

    return contacts

In [7]:
def load_contacts(load_type='contacts', **subject_dict):
    """Load electrode info from cmlreaders and add MATLAB localization labels.

    Parameters
    ----------
    load_type: str
        ``'contacts'`` or ``'pairs'``; used for both the cmlreaders load and
        the MATLAB ``talStruct`` file so the two tables describe the same
        kind of electrode.
    **subject_dict
        Keyword arguments for ``CMLReader``; must include ``'subject'`` and
        ``'montage'``.

    Returns
    -------
    pandas.DataFrame
        The cmlreaders table with added columns ``Loc1``..``Loc5``,
        ``stein.region`` (filled from the MATLAB ``locTag`` when cmlreaders
        does not provide it), ``stein.hemi`` (first word of the region) and
        ``stein.name`` (the remaining words).
    """
    reader = CMLReader(**subject_dict)
    contacts = reader.load(load_type)

    # Pass load_type through: previously the monopolar table was always used,
    # even when bipolar pairs were requested.
    matlab_contacts = load_matlab_contacts(
        subject_dict['subject'], subject_dict['montage'], load_type=load_type)
    # NOTE(review): columns are aligned on the DataFrame index — confirm both
    # tables list electrodes in the same order.
    for i in range(1, 6):
        contacts[f'Loc{i}'] = matlab_contacts[f'Loc{i}']

    if 'stein.region' not in contacts:
        contacts['stein.region'] = np.nan

        if 'locTag' in matlab_contacts:
            # np.array turns the mixed str/NaN list into strings, so missing
            # tags show up as the literal 'nan' handled below.
            loctags = np.array([t if len(t) > 0 else np.nan
                                for t in matlab_contacts['locTag']])
            contacts['stein.region'] = loctags

    def _is_missing(tag):
        return pd.isnull(tag) or tag == 'nan'

    contacts['stein.hemi'] = [np.nan if _is_missing(t)
                              else t.split(' ')[0]
                              for t in contacts['stein.region']]
    contacts['stein.name'] = [np.nan if _is_missing(t)
                              else ' '.join(t.split(' ')[1:])
                              for t in contacts['stein.region']]

    return contacts

In [8]:
def get_region_contacts(region, rcalc, contacts):
    """Return the ``'contact'`` values of rows whose ``rcalc`` column equals ``region``."""
    in_region = contacts[rcalc] == region
    return contacts.loc[in_region, 'contact'].values

In [9]:
def make_events_first_dim(ts, event_dim_str='event'):
    """
    Reorder the dimensions of a TimeSeries so the event dimension is first.
    Adapted from jfmiller's miller_ecog_tools.

    Parameters
    ----------
    ts: TimeSeries
        A PTSA TimeSeries object
    event_dim_str: str
        Name of the event dimension
    Returns
    -------
    TimeSeries
        ``ts`` itself when events already lead, otherwise a transposed
        version with the remaining dimensions kept in their original order.
    """
    dim_names = list(ts.dims)

    # Nothing to do when the event dimension already comes first.
    if dim_names[0] == event_dim_str:
        return ts

    # Event axis (if present) goes first, every other axis follows in order.
    event_axes = [pos for pos, dim in enumerate(dim_names) if dim == event_dim_str]
    other_axes = [pos for pos in range(ts.ndim) if pos not in event_axes]
    reordered = [dim_names[pos] for pos in event_axes + other_axes]
    return ts.transpose(*reordered)


In [10]:
def load_eeg(events, contacts, 
             rel_start_ms, rel_stop_ms, buf_ms=BUFFER,
             noise_freq=[58., 62.], resample_freq=None,
             pass_band=None, do_average_ref=True,
             **subject_dict,
            ):
    """Load, re-reference and filter EEG epochs for a set of contacts.

    Parameters
    ----------
    events: pandas.DataFrame
        Events to epoch around.
    contacts: array-like
        Contact numbers to keep in the returned data.
    rel_start_ms, rel_stop_ms: int
        Epoch window relative to each event, in milliseconds.
    buf_ms: int or None
        Padding added on both sides of the window; ``None`` means no padding.
    noise_freq: list or None
        [low, high] band for the Butterworth band-stop filter; ``None`` skips it.
    resample_freq: float or None
        Target sampling rate; ``None`` skips resampling.
    pass_band: list or None
        [low, high] band for the Butterworth band-pass filter; ``None`` skips it.
    do_average_ref: bool
        Subtract the mean across all loaded channels before selecting contacts.
    **subject_dict
        Keyword arguments for ``CMLReader`` / ``load_contacts``.

    Returns
    -------
    TimeSeries
        Data with the event dimension first, restricted to ``contacts``.

    Raises
    ------
    KeyError
        If loading fails on a requested contact, or on a key that cannot be
        removed from the electrode scheme.
    IndexError
        If a requested contact is not present in the electrode scheme.
    """
    # With buf_ms=None the window is loaded without padding (start/stop used
    # to be left undefined in that case).
    pad_ms = 0 if buf_ms is None else buf_ms
    start = rel_start_ms - pad_ms
    stop = rel_stop_ms + pad_ms

    reader = CMLReader(**subject_dict)
    elec_scheme = load_contacts(**subject_dict)

    # Retry loop: each KeyError names a contact the reader could not load; drop
    # it from the scheme and try again, unless it is one we actually need.
    loaded = False
    while not loaded:
        try:
            eeg = reader.load_eeg(
                events, scheme=elec_scheme,
                rel_start=start, rel_stop=stop,
            ).to_ptsa()
            loaded = True
        except KeyError as ke:
            try:
                bad_contact = int(str(ke).replace("'", ''))
            except ValueError:
                # Not a contact number: nothing we can drop, so re-raise.
                raise ke from None
            if bad_contact in contacts:
                raise ke
            remaining = elec_scheme[elec_scheme['contact'] != bad_contact]
            if len(remaining) == len(elec_scheme):
                # Removing it changed nothing; retrying would loop forever.
                raise ke
            elec_scheme = remaining

    # now auto cast to float32 to help with memory issues with high sample rate data
    eeg.data = eeg.data.astype('float32')

    if do_average_ref:
        # compute average reference by subtracting the mean across channels
        eeg = eeg - eeg.mean(dim='channel')

    eeg = make_events_first_dim(eeg)
    # Filter channels to only the desired contacts. Use each contact's
    # *position* in the scheme rather than its index label: after bad contacts
    # are dropped the labels no longer match positions.
    # NOTE(review): assumes channels come back in scheme row order — confirm.
    scheme_contacts = elec_scheme['contact'].values
    contact_locs = [np.flatnonzero(scheme_contacts == c)[0]
                    for c in contacts]
    eeg = eeg[:, contact_locs, :]

    # filter line noise
    if noise_freq is not None:
        b_filter = ButterworthFilter(eeg, noise_freq, filt_type='stop', order=4)
        filt = b_filter.filter()
        # I'm not sure why, but for subj R1241J this is necessary:
        filt['event'] = eeg['event']
        eeg[:] = filt

    # resample if desired. Note: can be a bit slow especially if have a lot of eeg data
    if resample_freq is not None:
        eeg_resamp = []
        for this_chan in range(eeg.shape[1]):
            r_filter = ResampleFilter(eeg[:, this_chan:this_chan + 1], resample_freq)
            eeg_resamp.append(r_filter.filter())
        coords = {x: eeg[x] for x in eeg.coords.keys()}
        coords['time'] = eeg_resamp[0]['time']
        coords['samplerate'] = resample_freq
        dims = eeg.dims
        eeg = TimeSeries.create(np.concatenate(eeg_resamp, axis=1),
                                resample_freq, coords=coords,
                                dims=dims)

    # do band pass if desired.
    if pass_band is not None:
        eeg = ButterworthFilter(eeg, pass_band, filt_type='pass', order=4).filter()

    eeg = make_events_first_dim(eeg)

    return eeg

In [11]:
def compute_hilbert(eeg):
    """ Compute power as the squared magnitude of the Hilbert analytic signal.

        abs(hilbert) gives the amplitude and square(amplitude) gives power,
        see: https://www.youtube.com/watch?v=VyLU8hlhI-I&t=421s
    """
    n_samples = eeg.shape[-1]
    # Transform at an FFT-friendly length for speed, then cut back to the
    # original number of samples
    # (see https://github.com/scipy/scipy/issues/6324).
    padded_len = scipy.fftpack.next_fast_len(n_samples)
    analytic = scipy.signal.hilbert(eeg, padded_len)[:, :, :n_samples]

    power = np.square(np.abs(analytic))

    # Wrap the result in a ptsa TimeSeries with the input's coords and dims.
    return TimeSeries(data=power,
                      coords=eeg.coords,
                      dims=eeg.dims)

In [12]:
def zscore_eeg(eeg_pow):
    """Z-score ``eeg_pow`` along its ``'time'`` axis.

    Returns a new TimeSeries with the same coords and dims as the input.
    """
    time_axis = eeg_pow.get_axis_num('time')
    standardized = zscore(eeg_pow, axis=time_axis)
    return TimeSeries(data=standardized,
                      coords=eeg_pow.coords,
                      dims=eeg_pow.dims)

In [13]:
def compute_power(eeg, buf_ms=BUFFER):
    """Hilbert power of ``eeg`` with the buffer removed, z-scored over time.

    ``buf_ms`` is the buffer length in milliseconds; it is converted to
    seconds for ``remove_buffer``.
    """
    buffer_s = buf_ms/1000.
    raw_power = compute_hilbert(eeg)
    trimmed_power = raw_power.remove_buffer(buffer_s)
    return zscore_eeg(trimmed_power)

In [14]:
def error_fill(xs, ys, err, color, label, axes=None):
    """Plot a line with a shaded ``ys - err`` .. ``ys + err`` band.

    Parameters
    ----------
    xs, ys, err:
        x values, line values and half-width of the band.
    color:
        Colour used for both the band and the line.
    label: str
        Legend label of the line.
    axes:
        Object providing ``fill_between`` and ``plot``; when ``None`` the
        ``pyplot`` module is used.
    """
    # Identity check: `== None` breaks on objects that override equality,
    # e.g. an array of axes.
    plotter = plt if axes is None else axes

    plotter.fill_between(xs, ys-err, ys+err,
                         alpha=.4, color=color)
    plotter.plot(xs, ys, label=label, color=color)

In [15]:
def get_roi_locs(subjects):
    """Flag, per subject, which regions of interest have at least one contact.

    Parameters
    ----------
    subjects: pandas.DataFrame
        One row per subject; each row is expanded into keyword arguments for
        ``load_contacts``.

    Returns
    -------
    dict
        Maps each ROI name to a boolean numpy array with one entry per row of
        ``subjects``, True where the ROI appears in ``stein.region``.
    """
    # Silences all warnings for the rest of the session, not just this call.
    warnings.filterwarnings('ignore')

    rois = ['Left EC', 'Right EC', 'Left CA1', 'Right CA1']
    found = {roi: [] for roi in rois}

    for _, subject in subjects.iterrows():
        regions = load_contacts(**subject)['stein.region'].unique()
        for roi in rois:
            found[roi].append(bool(roi in regions))

    return {roi: np.array(flags) for roi, flags in found.items()}

In [16]:
def get_subject_erps(region='Right CA1',
                       events_type=EVENTS_TYPE, freq_band=FREQ_BAND,
                       rel_start_ms=REL_START, rel_stop_ms=REL_STOP,
                       buf_ms=BUFFER, **subject_dict
                      ):
    """Compute z-scored power ERPs for one subject and one region.

    Parameters
    ----------
    region: str
        Value of ``stein.region`` selecting the contacts to analyse.
    events_type: str
        Event ``type`` to keep.
    freq_band: list
        [low, high] pass band in Hz applied before the Hilbert transform.
    rel_start_ms, rel_stop_ms: int
        Epoch window relative to each event, in milliseconds.
    buf_ms: int
        Buffer loaded around the epoch and removed after computing power.
    **subject_dict
        Keyword arguments identifying the subject (forwarded to the loaders).

    Returns
    -------
    pandas.DataFrame
        Columns ``rec_erp``/``rec_sem``, ``nrec_erp``/``nrec_sem`` and
        ``empty_erp``/``empty_sem`` (mean and standard error across channels
        for recalled, not-recalled and empty-chest events) plus ``samplerate``.
    """
    # get events
    events = get_all_events(**subject_dict)
    events = events[events['type'] == events_type]
    events.index = range(len(events))

    # get elecs (from this call's subject, not a notebook-level global)
    all_elecs = load_contacts(**subject_dict)
    contacts = get_region_contacts(region, 'stein.region', all_elecs)

    # get powers, honouring the requested epoch window
    eeg = load_eeg(events, contacts, **subject_dict,
                   rel_start_ms=rel_start_ms, rel_stop_ms=rel_stop_ms,
                   do_average_ref=False, pass_band=freq_band,
                   buf_ms=buf_ms,
                  )

    powers = compute_power(eeg, buf_ms)

    # split by recalled, not recalled, and empty chests
    # Cast to bool so the arrays act as masks even if 'recalled' is stored as
    # 0/1 integers (integers would be read as positions, not a mask).
    recalled_locs = events['recalled'].values.astype(bool)
    empty_locs = (events['item_name'] == '').values
    not_recalled_locs = ~(recalled_locs | empty_locs)

    def _erp_and_sem(event_mask):
        # Average over events first, then mean and SEM across channels.
        mean_power = powers[event_mask].mean(dim='event')
        erp = mean_power.mean(dim='channel')
        sem = mean_power.std(dim='channel') / np.sqrt(mean_power.channel.size)
        return mean_power, erp, sem

    rec_power, rec_erp, rec_sem = _erp_and_sem(recalled_locs)
    _, nrec_erp, nrec_sem = _erp_and_sem(not_recalled_locs)
    # Uses its own channel count (it used to borrow the not-recalled one).
    _, empty_erp, empty_sem = _erp_and_sem(empty_locs)

    # done
    return pd.DataFrame({'rec_erp': rec_erp, 'rec_sem': rec_sem,
                         'nrec_erp': nrec_erp, 'nrec_sem': nrec_sem,
                         'empty_erp': empty_erp, 'empty_sem': empty_sem,
                         'samplerate': rec_power.samplerate.values[()]
                        })

In [17]:
def plot_erp(erps, freq_band=FREQ_BAND,
             rel_start_ms=REL_START, rel_stop_ms=REL_STOP, axes=None):
    """Plot recalled / not-recalled / empty ERPs with their error bands.

    Parameters
    ----------
    erps: pandas.DataFrame
        Table with ``*_erp``, ``*_sem`` and ``samplerate`` columns, as
        returned by ``get_subject_erps``.
    freq_band: list
        [low, high] band in Hz, used only for the y-axis label.
    rel_start_ms: int
        Time of the first sample relative to the event, in milliseconds.
    rel_stop_ms: int
        End of the epoch; kept for interface compatibility. The time axis is
        built from the number of samples actually present.
    axes:
        Matplotlib axes to draw on; a new subplot is created when ``None``.
    """
    samplerate = erps['samplerate'].iloc[0]

    # One x value per sample. np.arange with a float step can come out one
    # element longer or shorter than the data, which breaks plotting.
    xs = rel_start_ms + np.arange(len(erps)) * (1000./samplerate)

    rec_color = 'firebrick'
    nrec_color = 'steelblue'
    empty_color = 'gray'

    if axes is None:
        axes = plt.subplot()

    error_fill(xs, erps['rec_erp'], erps['rec_sem'], color=rec_color, label='recalled', axes=axes)
    error_fill(xs, erps['nrec_erp'], erps['nrec_sem'], color=nrec_color, label='not recalled', axes=axes)
    error_fill(xs, erps['empty_erp'], erps['empty_sem'], color=empty_color, label='empty', axes=axes)

    axes.axvline(0, color='k')
    axes.axvline(1500, color='gray', linestyle=':') # item stops being displayed
    axes.axhline(0, color='k', linestyle='--')

    axes.set_xlabel('time (ms)')
    # Plain Unicode em dash: '\emdash' is not a mathtext symbol and makes the
    # label fail to render.
    axes.set_ylabel(f'Z(power) \u2014 [{freq_band[0]}-{freq_band[1]}] Hz')

    #axes.legend()

In [21]:
# Region of interest analysed in the cells below (a 'stein.region' label).
region = 'Right CA1'

In [22]:
# Restrict the 'r1' data index to the experiment under study.
df = get_data_index('r1')
df = df.loc[df['experiment'] == EXPERIMENT]

# One row per unique (subject, montage, localization, experiment), renumbered from 0.
subject_columns = ['subject', 'montage', 'localization', 'experiment']
subjects = df[subject_columns].drop_duplicates().reset_index(drop=True)

In [None]:
# Keep only the subjects that have at least one contact in `region`.
roi_locs = get_roi_locs(subjects)
has_region = roi_locs[region]
roi_subjects = subjects[has_region].reset_index(drop=True)

In [None]:
# Band used for both the filtering and the axis labels, so the two always agree.
freq_band = [6, 8]

# Roughly square grid; at least 1x1 so an empty subject list does not divide by zero.
n_subjects = len(roi_subjects)
ncols = max(int(np.sqrt(n_subjects)), 1)
nrows = max(int(np.ceil(n_subjects / ncols)), 1)

# squeeze=False keeps `axes` two-dimensional even for a single row or column.
fig, axes = plt.subplots(ncols=ncols, nrows=nrows, figsize=(15, 15),
                         squeeze=False)

all_recs = []
all_nrecs = []
all_emptys = []
skipped_subjects = []

for i, subject in tqdm(roi_subjects.iterrows(), total=n_subjects):
    # Panels are filled column by column.
    ax = axes[i % nrows, i // nrows]
    try:
        erps = get_subject_erps(**subject, region=region, freq_band=freq_band)

        all_recs.append(erps['rec_erp'])
        all_nrecs.append(erps['nrec_erp'])
        all_emptys.append(erps['empty_erp'])

        plot_erp(erps, freq_band=freq_band, axes=ax)
    except IndexError as err:
        # Best effort: keep going, but record which subjects were skipped
        # instead of dropping them silently.
        skipped_subjects.append((i, subject['subject'], repr(err)))
        continue

if skipped_subjects:
    print(f'skipped {len(skipped_subjects)} subject(s): {skipped_subjects}')

fig.tight_layout()