In [5]:
# Try to decode which feature was selected per-trial based on firing rates of neurons
# experiment with ranges of firing rates around fixation (selection) time

%load_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import matplotlib.patches as patches

import pandas as pd
import scipy.stats
from lfp_tools import (
    general as lfp_general,
    startup as lfp_startup,
    development as lfp_development,
    analysis as lfp_analysis
)
from spike_tools import (
    general as spike_general,
    analysis as spike_analysis,
)
import s3fs
import utils.behavioral_utils as behavioral_utils
import utils.spike_utils as spike_utils
import utils.classifier_utils as classifier_utils
import utils.visualization_utils as visualization_utils
import utils.io_utils as io_utils
from trial_splitters.random_splitter import RandomSplitter
from trial_splitters.block_splitter import BlockSplitter
from models.model_wrapper import ModelWrapper
from models.multinomial_logistic_regressor import MultinomialLogisticRegressor, NormedMultinomialLogisticRegressor
from models.multi_layer import MultiLayer

from models.trainer import Trainer
import pickle

from sklearn.cluster import KMeans
import plotly.express as px
from itertools import accumulate

import torch

matplotlib.rcParams['figure.dpi'] = 150


species = 'nhp'
subject = 'SA'
exp = 'WCST'
session = 20180802  # this is the session for which there are spikes at the moment. 
pre_interval = 1300
post_interval = 2000

feature_dims = ["Color", "Shape", "Pattern"]

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
# grab behavioral data, spike data, trial numbers. 
fs = s3fs.S3FileSystem()
behavior_file = spike_general.get_behavior_path(subject, session)
behavior_data = pd.read_csv(fs.open(behavior_file))
valid_beh = behavior_data[behavior_data.Response.isin(["Correct", "Incorrect"])]   
valid_beh = valid_beh[valid_beh.TrialNumber >= 57]

In [8]:
firing_rates = pd.read_pickle(fs.open(f"l2l.pqz317.scratch/firing_rates_{pre_interval}_fb_{post_interval}_100_bins.pickle"))
firing_rates = firing_rates[firing_rates.TrialNumber >= 57]

In [11]:
valid_beh.TrialNumber.to_numpy()


array([  57,   58,   59, ..., 1747, 1748, 1749])

In [4]:
valid_beh.TrialAfterRuleChange

57      0
58      1
59      2
60      3
61      4
       ..
1745    2
1746    3
1747    4
1748    5
1749    6
Name: TrialAfterRuleChange, Length: 1692, dtype: int64

In [9]:
valid_beh.groupby(["TrialAfterRuleChange"], as_index=False)

<pandas.core.groupby.generic.DataFrameGroupBy object at 0x7f4eec707eb0>

In [None]:
pre_interval = 1300
post_interval = 2000

models = np.load(fs.open(f"l2l.pqz317.scratch/fb_models_{pre_interval}_fb_{post_interval}_by_bin_random_split.npy"), allow_pickle=True)

# Look at just time at 0.2s after FB onset, 
model = models[15, 0]
time_bin = 1.5

mode = "SpikeCounts"
inputs = firing_rates.rename(columns={mode: "Value"})
inputs = inputs[np.isclose(inputs["TimeBins"], time_bin)]

x_test = classifier_utils.transform_to_input_data(inputs)
device = "cuda" if torch.cuda.is_available() else "cpu"
x_test = torch.Tensor(x_test).to(device)
y_test = model.model(x_test)

probs = y_test.softmax(dim=1)
dv = torch.log(probs[:, 0] / probs[:, 1])
np_dv = dv.detach().cpu().numpy()

cor_idxs = np.argwhere(cor_inc.to_numpy() == "Correct").squeeze()
inc_idxs = np.argwhere(cor_inc.to_numpy() == "Incorrect").squeeze()

In [None]:
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(np_dv[:(183 - 57)], label="Decision Var")
ax.scatter(cor_idxs, np.ones(len(cor_idxs)) * -5, label="True Corrects")
ax.scatter(inc_idxs, np.ones(len(inc_idxs)) * 10, label="True Incorrects")
ax.axvline(95 - 57, color="gray", linestyle="dotted")
ax.axvline(121 - 57, color="gray", linestyle="dotted")
ax.axvline(136 - 57, color="gray", linestyle="dotted")
ax.axhline(0, color="gray", linestyle="dotted")
ax.legend()