### Can the rule be decoded during the feedback (fb) period if we only look at the last 8 correct trials of each block?

In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import pandas as pd
import scipy.stats
from lfp_tools import (
    general as lfp_general,
    startup as lfp_startup,
    development as lfp_development,
    analysis as lfp_analysis
)
from spike_tools import (
    general as spike_general,
    analysis as spike_analysis,
)
import s3fs
import utils.behavioral_utils as behavioral_utils
import utils.spike_utils as spike_utils
import utils.classifier_utils as classifier_utils
import utils.visualization_utils as visualization_utils
import utils.io_utils as io_utils
from trial_splitters.random_splitter import RandomSplitter
from trial_splitters.block_splitter import BlockSplitter
from trial_splitters.kfold_splitter import KFoldSplitter
from trial_splitters.feature_block_splitter import FeatureBlockSplitter
from sklearn import svm
from sklearn.linear_model import LogisticRegression, LinearRegression
from models.value_models import ValueNormedModel, ValueNormedDropoutModel
import pickle

from models.multinomial_logistic_regressor import NormedDropoutMultinomialLogisticRegressor
from models.model_wrapper import ModelWrapper, ModelWrapperLinearRegression

from models.trainer import Trainer
from sklearn.cluster import KMeans

import torch
from torch import nn

import plotly.express as px
import matplotlib.patches as patches

import scipy.stats as sci_stats
import scipy

from itertools import accumulate


# Plot defaults: high-resolution figures with large fonts.
matplotlib.rcParams.update({'figure.dpi': 300, 'font.size': 20})

# Identifiers of the recording/session analyzed below.
species, subject, exp = 'nhp', 'SA', 'WCST'
session = 20180802  # this is the session for which there are spikes at the moment. 

feature_dims = ["Color", "Shape", "Pattern"]

# Pre/post window lengths (presumably ms) used when binning firing rates.
pre_interval, post_interval = 1300, 1500

In [2]:
# Load this session's behavioral data from S3 and keep only correct trials.
fs = s3fs.S3FileSystem()
behavior_file = spike_general.get_behavior_path(subject, session)
# Close the S3 handle once the CSV is parsed (pandas does not close buffers it is given).
with fs.open(behavior_file) as behavior_fh:
    behavior_data = pd.read_csv(behavior_fh)
# only look at corrects for this one
valid_beh = behavior_data[behavior_data.Response.isin(["Correct"])]
# NOTE: computed before the TrialNumber cutoff below, so it sees every correct
# trial; keep this ordering unless that is intentionally changed.
shuffled_card_idxs = behavioral_utils.get_shuffled_card_idxs(valid_beh)
# Drop trials numbered below 57 (reason for the cutoff not shown here — TODO confirm).
valid_beh = valid_beh[valid_beh.TrialNumber >= 57]

In [7]:
# get rules with more than 5 blocks associated with them:
# count distinct BlockNumber values per CurrentRule.
num_rules = valid_beh.groupby("CurrentRule")["BlockNumber"].nunique()
rules_more_than_five = num_rules[num_rules > 5].index.values

# Number of trials to keep from the end of each block.
last_n = 8

def label_trials(block_group, last_n):
    """Return the last ``last_n`` rows of one block, labelled by distance from its end.

    A ``TrialUntilRuleChange`` column is added that counts backwards from the
    end of the block: the final row gets 1, the one before it 2, and so on.
    The count is the row's position among the rows actually present in
    ``block_group`` (ordered by ``TrialNumber``). It is deliberately not
    derived from ``TrialAfterRuleChange``. That column is read straight from
    the behavior CSV before any filtering, so it presumably indexes the
    unfiltered block. Subtracting it from the length of a filtered group
    (e.g. correct trials only) mixes two scales and keeps the wrong trials.

    Args:
        block_group: DataFrame of rows for a single block; needs a
            ``TrialNumber`` column.
        last_n: number of rows to keep from the end of the block.

    Returns:
        A new DataFrame holding the last ``last_n`` rows (all rows if the block
        is shorter) with the ``TrialUntilRuleChange`` column. ``block_group``
        itself is not modified.
    """
    # sort_values returns a new frame, so adding a column never touches the caller's group.
    ordered = block_group.sort_values("TrialNumber")
    ordered["TrialUntilRuleChange"] = np.arange(len(ordered), 0, -1)
    return ordered[ordered["TrialUntilRuleChange"] <= last_n]
# Keep each block's final `last_n` trials, restricted to rules seen in more than 5 blocks.
block_groups = valid_beh.groupby(["BlockNumber"], as_index=False)
only_last_n = block_groups.apply(label_trials, last_n)
only_last_n = only_last_n.reset_index()

in_last_n = valid_beh["TrialNumber"].isin(only_last_n["TrialNumber"])
rule_has_enough_blocks = valid_beh["CurrentRule"].isin(rules_more_than_five)
valid_beh = valid_beh[in_last_n & rule_has_enough_blocks]

In [4]:
# Load precomputed firing rates and feature selections from S3. Each handle is
# closed after reading; pandas does not close buffers passed to it.
# NOTE: these are pickles — only load them from a trusted bucket.
with fs.open("l2l.pqz317.scratch/firing_rates_1300_fb_1500_100_bins.pickle") as frs_fh:
    frs = pd.read_pickle(frs_fh)
with fs.open("l2l.pqz317.scratch/feature_selections.pickle") as selections_fh:
    feature_selections = pd.read_pickle(selections_fh)

In [8]:
# Keep firing rates only for trials that survived the behavioral filtering (last-n, correct).
kept_trials = valid_beh["TrialNumber"]
frs = frs.loc[frs["TrialNumber"].isin(kept_trials)]

In [None]:
# Decode CurrentRule from population activity in each time bin and save the results.
num_neurons = len(frs.UnitID.unique())
classes = valid_beh["CurrentRule"].unique()
init_params = {"n_inputs": num_neurons, "p_dropout": 0.5, "n_classes": len(classes)}
trainer = Trainer(learning_rate=0.01)
wrapped = ModelWrapper(NormedDropoutMultinomialLogisticRegressor, init_params, trainer, classes)

mode = "SpikeCounts"

# prep data for classification
inputs = frs.rename(columns={mode: "Value"})
labels = valid_beh.rename(columns={"CurrentRule": "Feature"})

splitter = RandomSplitter(valid_beh.TrialNumber.unique(), 20, 0.2)

# Bin start times (seconds) spanning the whole pre+post window. They are built
# from an integer bin count rather than a float-stepped arange, so the number of
# bins cannot drift with rounding and stays in sync with pre_interval /
# post_interval. For 1300 + 1500 this yields the same 28 values as
# np.arange(0, 2.8, 0.1).
bin_size_ms = 100  # presumably ms, per the "100_bins" firing-rate file and the 0.1 s step
num_bins = (pre_interval + post_interval) // bin_size_ms
time_bins = np.arange(num_bins) * (bin_size_ms / 1000)

outputs = classifier_utils.evaluate_classifiers_by_time_bins(
    wrapped, inputs, labels, time_bins, splitter
)
io_utils.save_model_outputs(
    fs,
    "fb_rule_last_eights_normed_dropout",
    f"{pre_interval}_fb_{post_interval}",
    "random_split",
    outputs,
)