In [18]:
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import pandas as pd
import utils.behavioral_utils as behavioral_utils
import utils.information_utils as information_utils
import utils.visualization_utils as visualization_utils
import utils.glm_utils as glm_utils
from matplotlib import pyplot as plt
import utils.spike_utils as spike_utils
import utils.pca_utils as pca_utils
from constants.glm_constants import *
from constants.behavioral_constants import *
from spike_tools import (
    general as spike_general,
    analysis as spike_analysis,
)
from scipy import stats
import warnings
from scipy.ndimage import gaussian_filter1d
import seaborn as sns
import plotly.express as px


warnings.filterwarnings('ignore')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Alignment window for the binned firing rates (all durations in ms).
EVENT = "FixationOnCross"   # behavioral event the spikes are aligned to
PRE_INTERVAL = 500          # window length before EVENT
POST_INTERVAL = 500         # window length after EVENT
INTERVAL_SIZE = 50          # bin width
# Per-session firing-rate pickle; filled in with str.format in load_session_data.
SESS_SPIKES_PATH = (
    "/data/patrick_res/firing_rates/"
    "{sess_name}_firing_rates_{pre_interval}_{event}_{post_interval}_{interval_size}"
    "_bins_1_smooth.pickle"
)

In [12]:
def load_session_data(row, shuffle_idx=None, use_next_trial_entropy=False):
    """Load one session's behavior and binned firing rates, joined per trial.

    Parameters
    ----------
    row : object with a ``session_name`` attribute
        Typically one row of the sessions table.
    shuffle_idx : int or None
        When not None, passed as ``seed`` to
        ``behavioral_utils.shuffle_beh_by_shift`` to build a shifted
        (shuffled) control; None leaves trial labels untouched.
    use_next_trial_entropy : bool
        When True, each trial's ``ConfidenceBin`` is replaced by the value on
        the following behavior row, and the final row (which has no successor)
        is dropped. Note that this shifts ``ConfidenceBin``, not an entropy column.

    Returns
    -------
    pandas.DataFrame
        Inner join of the behavior table and the firing-rate table on
        ``TrialNumber``, with an added ``PseudoUnitID`` column.
    """
    sess_name = row.session_name

    # Behavior: valid trials, annotated with selected features and derived
    # value / probability / entropy / confidence columns.
    beh = pd.read_csv(SESS_BEHAVIOR_PATH.format(sess_name=sess_name))
    valid_beh = behavioral_utils.get_valid_trials(beh)
    feature_selections = behavioral_utils.get_selection_features(valid_beh)
    valid_beh = pd.merge(valid_beh, feature_selections, on="TrialNumber", how="inner")
    beh = behavioral_utils.get_feature_values_per_session(sess_name, valid_beh)
    beh = behavioral_utils.get_max_feature_value(beh)
    beh = behavioral_utils.calc_feature_probs(beh)
    beh = behavioral_utils.calc_feature_value_entropy(beh)
    beh = behavioral_utils.calc_confidence(beh, num_bins=3, quantize_bins=3)

    if use_next_trial_entropy:
        # Build a new frame instead of assigning into a boolean-filtered slice
        # (the chained-assignment pattern pandas warns about).
        # NOTE(review): shift(-1) moves by row position, so this assumes rows are
        # ordered by TrialNumber; after invalid trials are removed, "next" means
        # the next valid trial — confirm that is intended.
        beh = (
            beh.assign(ConfidenceBin=beh["ConfidenceBin"].shift(-1))
            .dropna(subset=["ConfidenceBin"])
            .astype({"ConfidenceBin": int})
        )

    # Shuffled control: shift TrialNumbers by a seed-dependent amount.
    if shuffle_idx is not None:
        beh = behavioral_utils.shuffle_beh_by_shift(beh, buffer=25, seed=shuffle_idx)

    spikes_path = SESS_SPIKES_PATH.format(
        sess_name=sess_name,
        pre_interval=PRE_INTERVAL,
        event=EVENT,
        post_interval=POST_INTERVAL,
        interval_size=INTERVAL_SIZE,
    )
    frs = pd.read_pickle(spikes_path)
    # Session-unique unit id: session number * 100 + per-session UnitID.
    # NOTE(review): ids collide across sessions if any UnitID >= 100 — confirm.
    frs["PseudoUnitID"] = int(sess_name) * 100 + frs["UnitID"]
    return pd.merge(beh, frs, on="TrialNumber")

In [13]:
valid_sess = pd.read_pickle(SESSIONS_PATH)
units = spike_utils.get_unit_positions(valid_sess)

# One merged behavior/firing-rate frame per session (a Series of DataFrames);
# missing entries are dropped before everything is stacked into one table.
res = valid_sess.apply(load_session_data, axis=1).dropna()
all_trials = pd.concat(res.values)

In [32]:
# Brain regions (level-2 structures) present across the recorded units.
units["structure_level2"].unique()

array(['lateral_and_ventral_pallium (LVPal)',
       'lateral_prefrontal_cortex (lat_PFC)',
       'primary_visual_cortex (V1)', 'anterior_cingulate_gyrus (ACgG)',
       'posterior_medial_cortex (PMC)', 'orbital_frontal_cortex (OFC)',
       'unknown', 'basal_ganglia (BG)', 'inferior_temporal_cortex (ITC)',
       'motor_cortex (motor)', 'preoptic_complex (POC)', 'amygdala (Amy)',
       'extrastriate_visual_areas_2-4 (V2-V4)', 'medial_pallium (MPal)',
       'thalamus (Thal)', 'inferior_parietal_lobule (IPL)',
       'superior_parietal_lobule (SPL)',
       'floor_of_the_lateral_sulcus (floor_of_ls)',
       'medial_temporal_lobe (MTL)'], dtype=object)

In [36]:
# Restrict the analysis to one brain region (a structure_level2 label such as
# "medial_pallium (MPal)"), or leave REGION = None to keep every unit.
REGION = None
region_units = (
    units.PseudoUnitID.unique()
    if REGION is None
    else units.loc[units.structure_level2 == REGION, "PseudoUnitID"].unique()
)
unit_trials = all_trials[all_trials.PseudoUnitID.isin(region_units)]

In [34]:
# Project the selected units' firing rates into PC space, conditioned on ConfidenceBin.
# NOTE(review): in the saved run this cell (In[34]) executed before the
# region-selection cell (In[36]), so the stored output may reflect an earlier
# `unit_trials`; re-run top to bottom to confirm.
transformed_df, pca = pca_utils.project_conditioned_firing_rates(
    unit_trials,
    "ConfidenceBin",
)

In [38]:
# Replace numeric confidence bins with readable labels for plotting.
# Values not in conf_map (including labels already applied) pass through
# unchanged, so re-running this cell no longer turns every label into NaN.
conf_map = {0: "low", 1: "medium", 2: "high"}
transformed_df["ConfidenceBin"] = transformed_df["ConfidenceBin"].map(
    lambda bin_value: conf_map.get(bin_value, bin_value)
)

In [39]:
# 3-D scatter of the first three principal components, colored by confidence.
# category_orders pins legend order (and color assignment) to low -> medium -> high.
# NOTE(review): marker size is driven by TimeBins; plotly requires non-negative
# sizes, so this assumes TimeBins >= 0 — confirm against the firing-rate pickles.
fig = px.scatter_3d(
    transformed_df,
    x="PC1",
    y="PC2",
    z="PC3",
    color="ConfidenceBin",
    category_orders={"ConfidenceBin": ["low", "medium", "high"]},
    opacity=1,
    size="TimeBins",
    title="Firing rates projected onto PC1-PC3, colored by confidence bin",
)
fig.show()