In [None]:
import mne
import numpy as np

import copy
import pickle
import scipy

import os
# Move to the parent of the notebook's starting directory (presumably where
# `_parameters` lives, since it is imported right after). The target is
# resolved only once per kernel, so re-running this cell does not keep
# climbing the directory tree the way a bare os.chdir('..') would.
if '_PROJECT_ROOT' not in globals():
    _PROJECT_ROOT = os.path.dirname(os.getcwd())
os.chdir(_PROJECT_ROOT)
from _parameters import *

from IPython.display import clear_output

In [None]:
def open_tfr(s, moment, modal, behavior = None):
    """Load one subject's per-condition time-frequency results from disk.

    Parameters
    ----------
    s : int | str
        Subject identifier; formatted into the file name as ``s{s}``.
    moment : str
        Time-locking label used in the file name (e.g. ``'enc1'``).
    modal : str
        Key into the module-level ``cond`` mapping; selects which
        conditions to load.
    behavior : str | None
        If given, load the behaviour-split files from the ``behavior``
        sub-directory instead of the default files.

    Returns
    -------
    dict
        Maps each condition name in ``cond[modal]`` (in that order) to the
        first object returned by ``mne.time_frequency.read_tfrs`` for that
        condition's file.
    """
    tfr = {}
    for c in cond[modal]:
        # Condition names may contain '/', which cannot be used in a file name.
        condition = c.replace('/', '_')
        if behavior is None:
            tfr_fname = dirs['tfr'] + f'/tfr_multi_{moment}_{condition}_s{str(s)}.h5'
        else:
            # NOTE(review): the double underscore before {behavior} is kept
            # as-is; presumably it matches the files written upstream — confirm.
            tfr_fname = dirs['tfr'] + f'/behavior/tfr_multi_{moment}__{behavior}_{condition}_s{str(s)}.h5'

        tfr[c] = mne.time_frequency.read_tfrs(fname=tfr_fname)[0]

    return tfr

In [None]:
def calc_cvsi(tfr, trials_L, trials_R, chan_L, chan_R):
    """Contralateral-vs-ipsilateral index (CVSI), averaged over both hemispheres.

    For each hemisphere channel the index is
    ``(contra - ipsi) / (contra + ipsi) * 100``: "contra" is the condition on
    the opposite side of that channel and "ipsi" the condition on the same side.
    The left-channel and right-channel indices are then averaged.

    Parameters
    ----------
    tfr : dict
        Maps condition names to TFR objects (as returned by ``open_tfr``).
        Each object must provide ``copy()``, an in-place ``pick()`` that
        returns the object, and ``.data``. It is not modified.
    trials_L, trials_R : str
        Keys of ``tfr`` for the left-side and right-side conditions.
    chan_L, chan_R : str
        Left-hemisphere and right-hemisphere channel names.

    Returns
    -------
    numpy.ndarray
        The CVSI, computed from each picked ``.data`` averaged over axis 0.
        (assumes ``.data`` is (channels, freqs, times), so the result is
        (freqs, times) — TODO confirm)
    """
    def _power(trials, chan):
        # ``pick`` works in place, so copy only the one condition we need
        # instead of deep-copying the whole dict of conditions.
        return tfr[trials].copy().pick(chan).data.mean(0)

    def _index(contra, ipsi):
        return ((contra - ipsi) / (contra + ipsi)) * 100

    # Left channel: right-side trials are contralateral.
    cvsi_L = _index(contra=_power(trials_R, chan_L), ipsi=_power(trials_L, chan_L))
    # Right channel: left-side trials are contralateral.
    cvsi_R = _index(contra=_power(trials_L, chan_R), ipsi=_power(trials_R, chan_R))

    return np.mean(np.asarray([cvsi_L, cvsi_R]), 0)

In [None]:
def get_cvsi(subs, moment):
    """Compute visual and motor CVSI for every subject plus the grand average.

    Returns ``(cvsi_dat, cvsi_avg)``. ``cvsi_dat`` maps ``'vis'`` and ``'mot'``
    to lists holding one CVSI array per subject, in ``subs`` order.
    ``cvsi_avg`` maps the same keys to the mean over subjects (axis 0).
    """
    cvsi_dat = {'vis': [], 'mot': []}

    for subject in subs:
        visual = open_tfr(subject, moment, 'visual')
        motor = open_tfr(subject, moment, 'motor')

        # Visual conditions use PO7/PO8, motor conditions use C3/C4.
        cvsi_dat['vis'].append(calc_cvsi(visual, 'itemL', 'itemR', 'PO7', 'PO8'))
        cvsi_dat['mot'].append(calc_cvsi(motor, 'respL', 'respR', 'C3', 'C4'))

        clear_output(wait = False)

    cvsi_avg = {key: np.mean(np.asarray(per_sub), 0) for key, per_sub in cvsi_dat.items()}
    return cvsi_dat, cvsi_avg

In [None]:
def get_cvsi_samerev(subs, moment):
    """Compute CVSI for the same/reverse split, per subject and averaged.

    Returns ``(cvsi_dat, cvsi_avg)``, keyed by ``'vis/same'``, ``'vis/rvrs'``,
    ``'mot/same'`` and ``'mot/rvrs'``. ``cvsi_dat`` holds one CVSI array per
    subject (in ``subs`` order). ``cvsi_avg`` holds the mean over subjects.
    """
    # (output key, loaded modality, left trials, right trials, left chan, right chan)
    specs = [
        ('vis/same', 'vis', 'same/itemL', 'same/itemR', 'PO7', 'PO8'),
        ('vis/rvrs', 'vis', 'rvrs/itemL', 'rvrs/itemR', 'PO7', 'PO8'),
        ('mot/same', 'mot', 'same/respL', 'same/respR', 'C3', 'C4'),
        ('mot/rvrs', 'mot', 'rvrs/respL', 'rvrs/respR', 'C3', 'C4'),
    ]
    cvsi_dat = {spec[0]: [] for spec in specs}

    for subject in subs:
        loaded = {
            'vis': open_tfr(subject, moment, 'vis-samerev'),
            'mot': open_tfr(subject, moment, 'mot-samerev'),
        }
        for key, modality, left, right, chan_left, chan_right in specs:
            cvsi_dat[key].append(
                calc_cvsi(loaded[modality], left, right, chan_left, chan_right)
            )

        clear_output(wait = False)

    cvsi_avg = {key: np.mean(np.asarray(vals), 0) for key, vals in cvsi_dat.items()}
    return cvsi_dat, cvsi_avg

In [None]:
def get_cvsi_behavior(subs, moment, behavior):
    """Compute same/reverse CVSI separately for low and high behavioural groups.

    Files are loaded with the behaviour suffixes ``f'{behavior}_trade_low'``
    and ``f'{behavior}_trade_high'`` (see ``open_tfr``).

    Returns ``(cvsi_dat, cvsi_avg)``, keyed by ``'<base>/<beh>'``. ``<base>``
    is one of ``'vis/same'``, ``'vis/rvrs'``, ``'mot/same'``, ``'mot/rvrs'`` and
    ``<beh>`` is one of the two suffixes above. ``cvsi_dat`` holds one CVSI
    array per subject (in ``subs`` order). ``cvsi_avg`` holds the mean over
    subjects.
    """
    base_conds = ['vis/same', 'vis/rvrs', 'mot/same', 'mot/rvrs']
    beh_conds = [f'{behavior}_trade_low', f'{behavior}_trade_high']

    # Keys are ordered base-major. The loop name is ``base`` rather than
    # ``cond`` so it cannot be confused with the module-level ``cond`` mapping
    # that ``open_tfr`` reads.
    cvsi_cond = [f'{base}/{beh}' for base in base_conds for beh in beh_conds]
    cvsi_dat = {c: [] for c in cvsi_cond}

    for beh in beh_conds:
        for s in subs:
            tfr_vis_behav = open_tfr(s, moment, 'vis-samerev', beh)
            tfr_mot_behav = open_tfr(s, moment, 'mot-samerev', beh)

            cvsi_dat[f'vis/same/{beh}'].append(
                calc_cvsi(tfr_vis_behav, 'same/itemL', 'same/itemR', 'PO7', 'PO8')
            )
            cvsi_dat[f'vis/rvrs/{beh}'].append(
                calc_cvsi(tfr_vis_behav, 'rvrs/itemL', 'rvrs/itemR', 'PO7', 'PO8')
            )
            cvsi_dat[f'mot/same/{beh}'].append(
                calc_cvsi(tfr_mot_behav, 'same/respL', 'same/respR', 'C3', 'C4')
            )
            cvsi_dat[f'mot/rvrs/{beh}'].append(
                calc_cvsi(tfr_mot_behav, 'rvrs/respL', 'rvrs/respR', 'C3', 'C4')
            )

            clear_output(wait = False)

    # Average over subjects.
    cvsi_avg = {c: np.mean(np.asarray(cvsi_dat[c]), 0) for c in cvsi_cond}

    return cvsi_dat, cvsi_avg

In [None]:
def get_tc(data, fband: list, frange = (3, 40), sigma: float = 10):
    """Average a frequency band and smooth the result along time.

    Parameters
    ----------
    data : array-like
        Stack of time-frequency arrays indexed ``[item, freq, time]``
        (e.g. the per-subject CVSI list). Frequency row ``i`` is taken to be
        ``frange[0] + i``, i.e. 1-unit steps — TODO confirm against the TFR
        ``freqs``.
    fband : sequence of two ints
        Inclusive ``[low, high]`` band, in the same units as ``frange``.
    frange : sequence of two ints
        Frequency range of ``data``. Only ``frange[0]`` (the frequency of
        row 0) is used.
    sigma : float
        Standard deviation, in samples, of the Gaussian smoothing along the
        last (time) axis. Defaults to 10.

    Returns
    -------
    numpy.ndarray
        Band-averaged, smoothed time courses with shape ``(item, time)``.

    Raises
    ------
    ValueError
        If ``fband`` is reversed, starts below ``frange[0]``, or extends past
        the frequency axis of ``data``. These cases previously produced empty
        or wrapped slices and therefore silent NaNs.
    """
    # Imported explicitly: on older SciPy, `import scipy` does not load ndimage.
    from scipy.ndimage import gaussian_filter1d

    tc = np.asarray(data)

    # Convert band edges to row indices on the frequency axis.
    flow, fhigh = [f - frange[0] for f in fband]
    if not 0 <= flow <= fhigh < tc.shape[1]:
        raise ValueError(
            f'fband {list(fband)} is reversed or outside the frequency axis '
            f'(rows start at {frange[0]}, {tc.shape[1]} rows available)'
        )

    band_mean = tc[:, flow:fhigh + 1, :].mean(1)

    return gaussian_filter1d(band_mean, sigma)

In [None]:
def run_cvsi(moment, condition='samerev'):
    """Compute CVSI for one time-locking moment and pickle the results.

    Parameters
    ----------
    moment : str
        Time-locking label (e.g. ``'enc1'``, ``'prob1'``, ``'resp1'``).
    condition : str | None
        ``None`` for the plain visual/motor split, ``'samerev'`` for the
        same/reverse split, or ``'DT'`` / ``'err'`` for the behaviour split.

    Writes ``[cvsi_dat, cvsi_avg]`` to ``cvsi_{moment}[_{condition}].pkl``.
    Writes the band-averaged time courses, plus a ``'time'`` entry, to
    ``cvsi_tc_{moment}[_{condition}].pkl``. Both files go in ``dirs['cvsi']``.

    Raises
    ------
    ValueError
        For an unknown ``condition``, or a result key whose first segment is
        neither ``'vis'`` nor ``'mot'``.
    """
    # Per-subject CVSI and grand average. Every branch uses ``moment``, so the
    # data always matches the output file names below.
    if condition is None:
        cvsi_dat, cvsi_avg = get_cvsi(subjects, moment)
    elif condition == 'samerev':
        cvsi_dat, cvsi_avg = get_cvsi_samerev(subjects, moment)
    elif condition in ('DT', 'err'):
        cvsi_dat, cvsi_avg = get_cvsi_behavior(subjects, moment, condition)
    else:
        raise ValueError(
            f"Unknown condition {condition!r}; expected None, 'samerev', 'DT' or 'err'"
        )

    # Band-averaged, smoothed time courses. The band is chosen from the
    # modality prefix of each key ('vis...' or 'mot...').
    fbands = {'vis': [8, 12], 'mot': [13, 30]}
    cvsi_tc = {}
    for c, per_sub in cvsi_dat.items():
        modality = c.split('/')[0]
        if modality not in fbands:
            raise ValueError(f'Cannot infer frequency band for CVSI key {c!r}')
        cvsi_tc[c] = get_tc(per_sub, fbands[modality])

    # NOTE(review): the time axis is read from subject 1's 'itemL' file; this
    # assumes every subject and condition shares it — confirm.
    tfr = mne.time_frequency.read_tfrs(dirs['tfr'] + f'/tfr_multi_{moment}_itemL_s1.h5')[0]
    cvsi_tc['time'] = tfr.times

    # Output file names: the condition suffix is omitted for the plain split.
    suffix = '' if condition is None else f'_{condition}'
    cvsi_fname = dirs['cvsi'] + f'/cvsi_{moment}{suffix}.pkl'
    tc_fname = dirs['cvsi'] + f'/cvsi_tc_{moment}{suffix}.pkl'

    # `with` guarantees the files are closed even if pickling fails.
    with open(cvsi_fname, 'wb') as cvsi_dat_f:
        pickle.dump([cvsi_dat, cvsi_avg], cvsi_dat_f)

    with open(tc_fname, 'wb') as cvsi_tc_f:
        pickle.dump(cvsi_tc, cvsi_tc_f)

In [None]:
# Locked to encoding ('enc1'), using run_cvsi's default same/reverse split
# (condition='samerev').

run_cvsi('enc1')

In [None]:
# Encoding ('enc1') split by behavioural performance: one run per measure,
# each written to its own cvsi_*_{measure}.pkl files.
for behavior_split in ('DT', 'err'):
    run_cvsi('enc1', condition=behavior_split)

In [None]:
# Locked to probe (1) ('prob1'), default same/reverse split (condition='samerev').

run_cvsi('prob1')

In [None]:
# Locked to probe (2) ('prob2'), default same/reverse split (condition='samerev').

run_cvsi('prob2')

In [None]:
# Locked to response (1) ('resp1'), default same/reverse split (condition='samerev').

run_cvsi('resp1')


In [None]:
# Locked to response (2) ('resp2'), default same/reverse split (condition='samerev').

run_cvsi('resp2')