### Setup

In [1]:
# Try to decode which feature was selected per-trial based on firing rates of neurons
# experiment with ranges of firing rates around fixation (selection) time

%load_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import pandas as pd
import scipy.stats
from lfp_tools import (
    general as lfp_general,
    startup as lfp_startup,
    development as lfp_development,
    analysis as lfp_analysis
)
from spike_tools import (
    general as spike_general,
    analysis as spike_analysis,
)
import s3fs
import utils.behavioral_utils as behavioral_utils
import utils.spike_utils as spike_utils
import utils.classifier_utils as classifier_utils
import utils.visualization_utils as visualization_utils
import utils.io_utils as io_utils
from trial_splitters.random_splitter import RandomSplitter
from trial_splitters.block_splitter import BlockSplitter
from sklearn import svm
from sklearn.linear_model import LogisticRegression
import pickle

from models.value_models import ValueLinearModel
from models.model_wrapper import ModelWrapper

from models.trainer import Trainer

import torch

import plotly.express as px


# Set the default figure resolution (dots per inch) for all figures in this notebook.
matplotlib.rcParams['figure.dpi'] = 150


In [2]:
# Create an S3 filesystem handle; later cells read the behavioral data through it.
fs = s3fs.S3FileSystem()

### Load behavior for the example patient (TODO: examine which session has the most trials; one session is hard-coded for now)

In [10]:
subject = "sub-IR84"
session = "sess-1"  # example session with spikes

# Behavior CSV for one subject/session:
# .../rawdata/{subject}/{session}/behavior/{subject}-{session}-beh.csv
beh_path = f"human-lfp/wcst-preprocessed/rawdata/{subject}/{session}/behavior/{subject}-{session}-beh.csv"

# Open through a context manager so the S3 file handle is closed once the CSV
# has been parsed (passing fs.open(...) directly to read_csv never closes it).
with fs.open(beh_path) as beh_file:
    beh = pd.read_csv(beh_file)


In [11]:
# Preview the raw behavior table loaded above.
beh

Unnamed: 0,bmp_table_4,bmp_table_2,bmp_table_3,bmp_table_1,trials_thisRepN,trials_thisTrialN,trials_thisN,trials_thisIndex,key_resp_3_keys,key_resp_3_rt,...,rule,response_time,ans_correctness,date,frameRate,expName,session,participant,Var22,rule_shift
0,SBL.bmp,TYP.bmp,CGS.bmp,QMR.bmp,0,0,0,0,,,...,Y,1.113291,0,2018_Oct_24_1514,60.024014,wcst6,1,IR84,Y,no
1,SGP.bmp,QBL.bmp,TMS.bmp,CYR.bmp,0,1,1,1,,,...,Y,1.697558,0,2018_Oct_24_1514,60.024014,wcst6,1,IR84,Y,no
2,CMS.bmp,QGR.bmp,SYL.bmp,TBP.bmp,0,2,2,2,,,...,Y,1.780264,1,2018_Oct_24_1514,60.024014,wcst6,1,IR84,Y,no
3,SGL.bmp,TBR.bmp,CMP.bmp,QYS.bmp,0,3,3,3,,,...,Y,1.047041,1,2018_Oct_24_1514,60.024014,wcst6,1,IR84,Y,no
4,QGP.bmp,SMR.bmp,TYL.bmp,CBS.bmp,0,4,4,4,,,...,Y,0.648967,1,2018_Oct_24_1514,60.024014,wcst6,1,IR84,Y,no
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
184,SGS.bmp,TMP.bmp,QYR.bmp,CBL.bmp,0,195,195,195,,,...,Q,0.615304,0,2018_Oct_24_1514,60.024014,wcst6,1,IR84,Q,no
185,CGS.bmp,TMP.bmp,QBL.bmp,SYR.bmp,0,196,196,196,,,...,Q,0.397784,1,2018_Oct_24_1514,60.024014,wcst6,1,IR84,Q,no
186,TYL.bmp,QBP.bmp,SMR.bmp,CGS.bmp,0,197,197,197,,,...,Q,1.229608,0,2018_Oct_24_1514,60.024014,wcst6,1,IR84,Q,no
187,QGR.bmp,CBL.bmp,SYP.bmp,TMS.bmp,0,198,198,198,,,...,Q,0.697868,0,2018_Oct_24_1514,60.024014,wcst6,1,IR84,Q,no


In [12]:
# List the columns available in the behavior table.
beh.columns

Index(['bmp_table_4', 'bmp_table_2', 'bmp_table_3', 'bmp_table_1',
       'trials_thisRepN', 'trials_thisTrialN', 'trials_thisN',
       'trials_thisIndex', 'key_resp_3_keys', 'key_resp_3_rt',
       'key_resp_2_keys', 'key_resp_2_rt', 'rule', 'response_time',
       'ans_correctness', 'date', 'frameRate', 'expName', 'session',
       'participant', 'Var22', 'rule_shift'],
      dtype='object')

In [13]:
# Generate the BlockNumber column. Every "yes" in rule_shift starts a new block,
# so the running count of rule shifts up to and including a trial is its block
# index (0 for trials before the first shift). Vectorized equivalent of looping
# over beh.iterrows(); also leaves no loop-counter globals behind.
beh["BlockNumber"] = (beh["rule_shift"] == "yes").cumsum()

In [14]:
# Inspect the block assignment next to the rule and rule_shift columns.
beh.loc[:, ["BlockNumber", "rule", "ans_correctness", "rule_shift"]]

Unnamed: 0,BlockNumber,rule,ans_correctness,rule_shift
0,0,Y,0,no
1,0,Y,0,no
2,0,Y,1,no
3,0,Y,1,no
4,0,Y,1,no
...,...,...,...,...
184,9,Q,0,no
185,9,Q,1,no
186,9,Q,0,no
187,9,Q,0,no


In [17]:
# Trials whose rule is "S": the response key, correctness, and the four
# bmp_table_* card columns, to see which dimension the S rule refers to.
beh.loc[
    beh["rule"] == "S",
    [
        "key_resp_2_keys", "rule", "ans_correctness",
        "bmp_table_4", "bmp_table_3", "bmp_table_2", "bmp_table_1",
    ],
]

Unnamed: 0,key_resp_2_keys,rule,ans_correctness,bmp_table_4,bmp_table_3,bmp_table_2,bmp_table_1
7,up,S,0,SMR.bmp,CGS.bmp,TBP.bmp,QYL.bmp
8,down,S,0,SGL.bmp,CBP.bmp,TYS.bmp,QMR.bmp
9,right,S,0,SYP.bmp,CMR.bmp,TBS.bmp,QGL.bmp
10,right,S,0,QYS.bmp,TBR.bmp,SMP.bmp,CGL.bmp
11,left,S,0,QGS.bmp,SBR.bmp,TMP.bmp,CYL.bmp
12,left,S,1,SGR.bmp,TML.bmp,QBS.bmp,CYP.bmp
13,left,S,1,SGP.bmp,TYS.bmp,CMR.bmp,QBL.bmp
14,up,S,0,CYL.bmp,TMR.bmp,SBP.bmp,QGS.bmp
15,left,S,0,QGR.bmp,SYP.bmp,TBL.bmp,CMS.bmp
16,up,S,0,QBS.bmp,CGL.bmp,SMR.bmp,TYP.bmp


In [None]:
# Regenerate the BlockNumber column (same result as the earlier BlockNumber cell,
# so re-running is harmless). Every "yes" in rule_shift starts a new block, so the
# running count of rule shifts up to and including a trial is its block index.
beh["BlockNumber"] = (beh["rule_shift"] == "yes").cumsum()


def _rename_third_dim_s(card):
    """Return `card` with an "S" in its third character (index 2) replaced by "Z".

    The letter S appears in two different feature dimensions of a card name
    (first character in e.g. "SBL.bmp", third character in e.g. "CGS.bmp").
    Renaming the third-dimension S to Z makes every feature letter unambiguous.
    Values that are not strings of at least three characters (e.g. NaN) are
    returned unchanged.
    """
    if isinstance(card, str) and len(card) > 2 and card[2] == "S":
        return card[:2] + "Z" + card[3:]
    return card


# The rule S could be in 2 different dimensions: assign S in the last dim to be Z.
# Map column-by-column instead of writing into beh while iterating over iterrows().
# Block numbering is done once above, not inside this loop.
for card_pos in range(1, 5):
    column_name = f"bmp_table_{card_pos}"
    beh[column_name] = beh[column_name].map(_rename_third_dim_s)


# TODO: the `rule` column still reads "S" regardless of which dimension S refers to.
# Plan for relabeling it:
#   - Blocks are identified by BlockNumber (above); the number of blocks is the
#     number of distinct BlockNumber values (i.e. counted from rule_shift).
#   - For each block whose rule is S, find a correctly answered trial
#     (presumably ans_correctness == 1 -- confirm).
#   - Look at the selected key (key_resp_2_keys) and map it to the chosen card
#     among bmp_table_*. NOTE(review): the key -> card-position mapping is not
#     defined in this notebook yet; confirm it.
#   - Check the position of S on the chosen card: if S is in the third dimension
#     (now renamed Z), switch the block's rule to Z.