In [25]:
# Try to decode which feature was selected on each trial, based on the firing rates of neurons.
# Experiment with ranges of firing rates around fixation (selection) time.

%load_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import pandas as pd
import scipy.stats
from lfp_tools import (
    general as lfp_general,
    startup as lfp_startup,
    development as lfp_development,
    analysis as lfp_analysis
)
from spike_tools import (
    general as spike_general,
    analysis as spike_analysis,
)
import s3fs
import utils.behavioral_utils as behavioral_utils
import utils.spike_utils as spike_utils
import utils.classifier_utils as classifier_utils
import utils.visualization_utils as visualization_utils
import utils.io_utils as io_utils
from trial_splitters.random_splitter import RandomSplitter
from trial_splitters.block_splitter import BlockSplitter
from sklearn import svm
from sklearn.linear_model import LogisticRegression
import pickle

from models.value_model import ValueModel

from models.trainer import Trainer

import torch

# Render figures at a higher resolution than the matplotlib default.
matplotlib.rcParams.update({'figure.dpi': 150})


# Identifiers of the recording analysed in this notebook.
species, subject, exp = 'nhp', 'SA', 'WCST'
# This is the session for which there are spikes at the moment.
session = 20180802

# Extent of the analysis window on either side of the event of interest.
# NOTE(review): presumably milliseconds, and presumably the same 1300/1500 that
# appear in the firing-rate pickle filename loaded below — TODO confirm.
pre_interval, post_interval = 1300, 1500

# Names of the feature dimensions that describe each card.
feature_dims = ["Color", "Shape", "Pattern"]

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
# Load this session's behavioral data from S3 and keep only the trials whose
# response was scored as "Correct" or "Incorrect".
fs = s3fs.S3FileSystem()
behavior_file = spike_general.get_behavior_path(subject, session)
# Open the remote file in a context manager so the handle is closed as soon as
# the CSV has been parsed, instead of being left open for the kernel's lifetime.
with fs.open(behavior_file) as behavior_fh:
    behavior_data = pd.read_csv(behavior_fh)
valid_beh = behavior_data[behavior_data.Response.isin(["Correct", "Incorrect"])]

In [4]:
# Load the precomputed feature selections and binned firing rates from S3.
# NOTE(review): pd.read_pickle can execute arbitrary code while unpickling;
# only load these files from a bucket you trust.
# Each remote file is opened in a context manager so its handle is closed once
# the object has been read.
with fs.open("l2l.pqz317.scratch/feature_selections.pickle") as selections_fh:
    feature_selections = pd.read_pickle(selections_fh)
with fs.open("l2l.pqz317.scratch/firing_rates_1300_fb_1500_100_bins.pickle") as rates_fh:
    firing_rates = pd.read_pickle(rates_fh)

In [8]:
# Vocabulary of card features; a feature's position in this list is the integer
# code stored for it in `cards`.
FEATURES = [
    'CIRCLE', 'SQUARE', 'STAR', 'TRIANGLE',
    'CYAN', 'GREEN', 'MAGENTA', 'YELLOW',
    'ESCHER', 'POLKADOT', 'RIPPLE', 'SWIRL',
]

# cards[row, card, dim] is the FEATURES index of the value found in column
# "Item{card}{dim}" for that row of valid_beh. Every slot is overwritten by the
# loops below, so starting from an uninitialised array is safe.
cards = np.empty((len(valid_beh), 4, 3))
for card_idx in range(4):
    for dim_idx, dim in enumerate(["Color", "Shape", "Pattern"]):
        column = valid_beh[f"Item{card_idx}{dim}"]
        # list.index raises ValueError if a value is not a known feature.
        cards[:, card_idx, dim_idx] = [FEATURES.index(name) for name in column]


        

In [17]:
# Playing around with the value model to see if it works.
# Select the rows whose time bin is (numerically) 0.1, relabel the spike-count
# column as "Value", and hand the result to the project's input transformer.
is_target_bin = np.isclose(firing_rates["TimeBins"], 0.1)
inputs_for_bin = firing_rates[is_target_bin]
renamed = inputs_for_bin.rename({"SpikeCounts": "Value"}, axis="columns")
model_inputs = classifier_utils.transform_to_input_data(renamed)


In [24]:
# Instantiate the value model.
# NOTE(review): 59 and 12 are magic numbers here — presumably the number of
# input units (neurons) and the number of features (FEATURES has 12 entries);
# confirm against ValueModel's constructor.
model = ValueModel(59, 12)

In [47]:
# Run the model once on the prepared inputs and the encoded cards, then report
# the shape of what it returns.
input_tensor = torch.Tensor(model_inputs)
res = model(input_tensor, cards)
print(res.size())

torch.Size([1749, 12])
torch.Size([1749, 4])
