In [None]:
# Compare per-unit spike counts around feedback between explore and exploit states,
# split by correct vs. incorrect responses (equal-sized random samples of trials per group)

%load_ext autoreload
%autoreload 2

import os
import pickle
from itertools import accumulate

import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import matplotlib.patches as patches

import pandas as pd
import scipy.stats
from lfp_tools import (
    general as lfp_general,
    startup as lfp_startup,
    development as lfp_development,
    analysis as lfp_analysis
)
from spike_tools import (
    general as spike_general,
    analysis as spike_analysis,
)
import s3fs
import utils.behavioral_utils as behavioral_utils
import utils.spike_utils as spike_utils
import utils.classifier_utils as classifier_utils
import utils.visualization_utils as visualization_utils
import utils.io_utils as io_utils
from trial_splitters.random_splitter import RandomSplitter
from trial_splitters.block_splitter import BlockSplitter
from models.model_wrapper import ModelWrapper
from models.multinomial_logistic_regressor import (
    MultinomialLogisticRegressor, 
    NormedMultinomialLogisticRegressor,
    NormedDropoutMultinomialLogisticRegressor,
)
from models.multi_layer import MultiLayer

from models.trainer import Trainer

from sklearn.cluster import KMeans
import plotly.express as px

import torch
# Render inline figures at 150 dpi.
matplotlib.rcParams.update({'figure.dpi': 150})

In [None]:
# Session identifiers.
species, subject, exp = 'nhp', 'SA', 'WCST'
session = 20180802  # this is the session for which there are spikes at the moment. 

# Window around feedback; both values are interpolated into the firing-rate
# pickle filename loaded further down.
pre_interval = 1300   # presumably ms before feedback — TODO confirm
post_interval = 2000  # presumably ms after feedback — TODO confirm

feature_dims = ["Color", "Shape", "Pattern"]

In [None]:
# grab behavioral data, spike data, trial numbers. 
# Load the session's behavioral table from S3 and keep only trials with a
# Correct/Incorrect response, from trial 57 onward.
# NOTE(review): why trials before 57 are dropped is not documented here — confirm.
min_trial_number = 57

fs = s3fs.S3FileSystem()
behavior_file = spike_general.get_behavior_path(subject, session)
# Close the S3 handle once the CSV has been read (read_csv reads eagerly).
with fs.open(behavior_file) as behavior_fp:
    behavior_data = pd.read_csv(behavior_fp)

valid_beh = behavior_data[
    behavior_data.Response.isin(["Correct", "Incorrect"])
    & (behavior_data.TrialNumber >= min_trial_number)
].copy()  # own the data so later column additions don't hit SettingWithCopyWarning

In [None]:
# Long-format table of binned spike counts loaded from S3. Columns used below:
# UnitID, TrialNumber, TimeBins, SpikeCounts.
# NOTE(review): unpickling can execute arbitrary code — only load files from
# this trusted scratch bucket.
firing_rates_path = f"l2l.pqz317.scratch/firing_rates_{pre_interval}_fb_{post_interval}_50_bins.pickle"
with fs.open(firing_rates_path) as firing_rates_fp:
    firing_rates_50 = pd.read_pickle(firing_rates_fp)
# Same trial cutoff as applied to the behavioral table.
firing_rates_50 = firing_rates_50[firing_rates_50.TrialNumber >= 57]

In [None]:
# One numeric code per valid trial, in trial order: 200 for "Correct",
# 206 otherwise (only "Incorrect" remains after filtering). Consumed by
# get_exploration in the next cell.
# NOTE(review): presumably these mirror the task's event codes — confirm.
response_codes = [
    200 if response == "Correct" else 206
    for response in valid_beh["Response"]
]

In [None]:
# Label each valid trial from its response-code sequence; later cells treat
# explore == 1 as explore trials and explore == 0 as exploit trials.
# NOTE(review): the meaning of get_exploration's second argument (1) is not
# visible here — confirm.
explore_exploit = lfp_development.get_exploration(np.array(response_codes), 1)
# assign returns a new frame rather than writing a column into what may be a
# filtered slice of behavior_data (which raises SettingWithCopyWarning).
valid_beh = valid_beh.assign(explore=explore_exploit)

In [None]:
# Split valid trials by behavioral state, then report how many correct and
# incorrect trials each state has; the per-group sample size chosen below
# cannot exceed these counts.
explore_trials = valid_beh[valid_beh.explore == 1]
exploit_trials = valid_beh[valid_beh.explore == 0]

for state_name, state_trials in (("explore", explore_trials), ("exploit", exploit_trials)):
    for response in ("Correct", "Incorrect"):
        n_trials = len(state_trials[state_trials.Response == response])
        print(f"Number of {response.lower()} trials in {state_name} state: {n_trials}")


In [None]:
# Balance the four (state x response) groups by drawing the same number of
# trials from each (DataFrame.sample draws without replacement by default).
# All draws share one seeded Generator, so the call order below is part of
# what makes the selection reproducible.
rng = np.random.default_rng(seed=42)
num_samples = 285 # smallest common number


def _sample_group(state_trials, response):
    """Return num_samples random rows of state_trials whose Response equals response."""
    matching = state_trials[state_trials.Response == response]
    return matching.sample(num_samples, random_state=rng)


explore_cor_sampled = _sample_group(explore_trials, 'Correct')
explore_inc_sampled = _sample_group(explore_trials, 'Incorrect')
exploit_cor_sampled = _sample_group(exploit_trials, 'Correct')
exploit_inc_sampled = _sample_group(exploit_trials, 'Incorrect')


### Look at a specific neuron's firing

In [None]:
# Pick one example unit and build, for each of the four sampled trial groups,
# a (time bins x trials) matrix of its spike counts.
neuron_id = 52

unit_fr = firing_rates_50[firing_rates_50.UnitID == neuron_id]

explore_cor_fr = unit_fr[unit_fr.TrialNumber.isin(explore_cor_sampled.TrialNumber)]
explore_inc_fr = unit_fr[unit_fr.TrialNumber.isin(explore_inc_sampled.TrialNumber)]
exploit_cor_fr = unit_fr[unit_fr.TrialNumber.isin(exploit_cor_sampled.TrialNumber)]
exploit_inc_fr = unit_fr[unit_fr.TrialNumber.isin(exploit_inc_sampled.TrialNumber)]

# pivot makes each column one TrialNumber explicitly (instead of relying on row
# order to line trials up across bins) and sorts rows by TimeBins.
trans_explore_cor = explore_cor_fr.pivot(index="TimeBins", columns="TrialNumber", values="SpikeCounts").to_numpy()
trans_explore_inc = explore_inc_fr.pivot(index="TimeBins", columns="TrialNumber", values="SpikeCounts").to_numpy()
trans_exploit_cor = exploit_cor_fr.pivot(index="TimeBins", columns="TrialNumber", values="SpikeCounts").to_numpy()
trans_exploit_inc = exploit_inc_fr.pivot(index="TimeBins", columns="TrialNumber", values="SpikeCounts").to_numpy()

In [None]:
# Two panels for the unit selected above: explore trials (left) and exploit
# trials (right), each overlaying correct vs. incorrect spike counts over time.
# The numeric arguments (1.3, 2, .05) are presumably the pre-/post-feedback
# window in seconds and the 50 ms bin width — TODO confirm against
# visualize_accuracy_across_time_bins.
panel_curves = [
    [(trans_explore_cor, "Explore Correct"), (trans_explore_inc, "Explore Incorrect")],
    [(trans_exploit_cor, "Exploit Correct"), (trans_exploit_inc, "Exploit Incorrect")],
]

fig, axs = plt.subplots(1, 2, figsize=(10, 5))
for ax, curves in zip(axs, panel_curves):
    for counts, label in curves:
        visualization_utils.visualize_accuracy_across_time_bins(
            counts,
            1.3, 2, .05,
            ax,
            label=label,
            right_align=True
        )
    # NOTE(review): what the shaded span [-0.8, 0] s and the dashed line at
    # 0.098 s mark is not documented here — TODO confirm.
    ax.axvspan(-0.8, 0, alpha=0.3, color='gray')
    ax.axvline(0.098, alpha=0.3, color='gray', linestyle='dashed')
    ax.set_xlabel("Time Relative to Feedback (s)")
    # Derive the label from neuron_id so it stays correct when the unit changes.
    ax.set_ylabel(f"Unit {neuron_id} Spike Counts")
    ax.legend()
plt.show()

### Plot every unit in one loop and save each figure

In [None]:
# Repeat the single-unit figure for every unit and save one PNG per unit.
# NOTE(review): assumes units are numbered 0..58 in this session; consider
# iterating over firing_rates_50.UnitID.unique() instead.
unit_ids = range(59)
output_dir = "../data/responses"
# savefig does not create missing directories.
os.makedirs(output_dir, exist_ok=True)

# Left panel: explore trials; right panel: exploit trials.
panel_groups = [
    [("Explore Correct", explore_cor_sampled), ("Explore Incorrect", explore_inc_sampled)],
    [("Exploit Correct", exploit_cor_sampled), ("Exploit Incorrect", exploit_inc_sampled)],
]

for neuron_id in unit_ids:
    unit_fr = firing_rates_50[firing_rates_50.UnitID == neuron_id]

    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    for ax, groups in zip(axs, panel_groups):
        for label, sampled in groups:
            group_fr = unit_fr[unit_fr.TrialNumber.isin(sampled.TrialNumber)]
            # (time bins x trials) matrix; pivot aligns each column to one
            # TrialNumber explicitly and sorts rows by TimeBins.
            counts = group_fr.pivot(
                index="TimeBins", columns="TrialNumber", values="SpikeCounts"
            ).to_numpy()
            visualization_utils.visualize_accuracy_across_time_bins(
                counts,
                1.3, 2, .05,
                ax,
                label=label,
                right_align=True
            )
        ax.axvspan(-0.8, 0, alpha=0.3, color='gray')
        ax.axvline(0.098, alpha=0.3, color='gray', linestyle='dashed')
        ax.set_xlabel("Time Relative to Feedback (s)")
        ax.set_ylabel(f"Unit {neuron_id} Spike Counts")
        ax.legend()

    fig.savefig(os.path.join(output_dir, f"{neuron_id}.png"))
    # Close each figure after saving; otherwise all of them stay open in memory
    # (matplotlib warns once more than 20 are open).
    plt.close(fig)