### Another shot at dimensionality reduction techniques
Want to try PCA again, this time with the following setup:
- Restrict to cells from either HC or OFC only; this is a small population (19 cells in HC, 18 in OFC)
- Condition on one selected feature at a time
- Group trials into 3 groups: 
  - A: high feature val, high confidence
  - B: low feature val, high confidence
  - C: low feature val, low confidence

Also want to try:
- 50ms time bins, smoothed with 50ms std Gaussian

### Load Data, Imports

In [46]:
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import pandas as pd
import utils.pseudo_utils as pseudo_utils
import utils.pseudo_classifier_utils as pseudo_classifier_utils
import utils.behavioral_utils as behavioral_utils
from utils.session_data import SessionData
import utils.io_utils as io_utils
from utils.constants import *
import json

from spike_tools import (
    general as spike_general,
    analysis as spike_analysis,
)

import matplotlib.pyplot as plt
import matplotlib

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [30]:
# Directory where the outputs of this analysis are stored.
OUTPUT_DIR = "/data/patrick_res/pseudo"
# Path to a pickled dataframe listing the sessions to analyze.
# Alternative session list (currently disabled):
# SESSIONS_PATH = "/data/patrick_scratch/multi_sess/valid_sessions.pickle"
SESSIONS_PATH = "/data/patrick_res/sessions/valid_sessions_rpe.pickle"
# Per-session behavior CSV; fill in with `.format(sess_name=...)`.
SESS_BEHAVIOR_PATH = "/data/rawdata/sub-SA/sess-{sess_name}/behavior/sub-SA_sess-{sess_name}_object_features.csv"
# Per-session pickle of spikes that have been pre-aligned to event time and binned.
# Placeholders: sess_name, pre_interval, event, post_interval, interval_size.
SESS_SPIKES_PATH = "/data/patrick_res/firing_rates/{sess_name}_firing_rates_{pre_interval}_{event}_{post_interval}_{interval_size}_bins_1_smooth.pickle"

# Names of the feature dimensions (presumably the stimulus dimensions of the task; confirm).
FEATURE_DIMS = ["Color", "Shape", "Pattern"]

### Per session, label trials
Each trial needs a confidence label as well as its feature values.

In [64]:
def get_labels_for_session(session, feat):
    """Label the trials of one session in which `feat` was the selected feature.

    Loads the session's behavior CSV, keeps the valid trials whose selected
    value along `feat`'s dimension equals `feat`, attaches per-trial feature
    values and RPE/probability columns via `behavioral_utils`, and adds three
    label columns.

    Args:
        session: session name, substituted into `SESS_BEHAVIOR_PATH` and
            stored in the `Session` column.
        feat: feature name; must be a key of `FEATURE_TO_DIM`.

    Returns:
        A new DataFrame with one row per retained trial and these added columns:
          - `Conf`: "high" if the trial's `Prob_FE` is strictly greater than
            the median `Prob_FE` of the retained trials, else "low" (ties and
            NaN comparisons fall into "low").
          - `MaxFeatMatches`: whether the trial's `MaxFeat` equals `feat`.
          - `Session`: the `session` argument.
        The columns are always present, even when no trial is retained.
    """
    behavior_path = SESS_BEHAVIOR_PATH.format(sess_name=session)
    beh = pd.read_csv(behavior_path)
    valid_beh = behavioral_utils.get_valid_trials(beh)
    feature_selections = behavioral_utils.get_selection_features(valid_beh)
    valid_beh_merged = pd.merge(valid_beh, feature_selections, on="TrialNumber", how="inner")

    # Keep only the trials whose selected value in this dimension is `feat`.
    feat_dim = FEATURE_TO_DIM[feat]
    chose_feat = valid_beh_merged[valid_beh_merged[feat_dim] == feat]

    valid_beh_vals = behavioral_utils.get_feature_values_per_session(session, chose_feat)
    valid_beh_vals_conf = behavioral_utils.get_rpes_per_session(session, valid_beh_vals)

    # Median split on Prob_FE, computed over the retained trials only.
    # Guard the empty case so np.median does not warn on an empty slice;
    # comparing against NaN simply labels nothing as "high".
    prob_fe = valid_beh_vals_conf["Prob_FE"].to_numpy()
    med_conf = np.median(prob_fe) if prob_fe.size else np.nan

    # Vectorized labelling: unlike a row-wise apply, this keeps the original
    # column dtypes and still creates `Conf` when the frame has zero rows.
    # `assign` returns a new frame, so the helper's output is not mutated.
    return valid_beh_vals_conf.assign(
        Conf=np.where(prob_fe > med_conf, "high", "low"),
        MaxFeatMatches=valid_beh_vals_conf["MaxFeat"] == feat,
        Session=session,
    )

In [65]:
# Feature whose selection trials are labelled in this pass.
feature = "CYAN"
valid_sessions = pd.read_pickle(SESSIONS_PATH)
# Build one labelled frame per session row and stack them in session order.
res = pd.concat([
    get_labels_for_session(sess_row.session_name, feature)
    for _, sess_row in valid_sessions.iterrows()
])

In [67]:
# Non-null counts per column for each (session, confidence, max-feature-match)
# group; only the first 50 groups are displayed.
(
    res
    .groupby(["Session", "Conf", "MaxFeatMatches"])
    .count()
    .head(50)
)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,TrialNumber,BlockNumber,TrialAfterRuleChange,TaskInterrupt,ConditionNumber,Response,ItemChosen,TrialType,CurrentRule,LastRule,...,SWIRL,MaxFeat,trial_y,fb,Prob_FE,Prob_FD,Prob_FRL,RPE_FE,RPE_FD,RPE_FRL
Session,Conf,MaxFeatMatches,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1
20180705,high,False,81,81,81,1,81,81,81,81,81,72,...,81,81,81,81,81,81,81,81,81,81
20180705,high,True,80,80,80,0,80,80,80,80,80,80,...,80,80,80,80,80,80,80,80,80,80
20180705,low,False,106,106,106,1,106,106,106,106,106,90,...,106,106,106,106,106,106,106,106,106,106
20180705,low,True,56,56,56,0,56,56,56,56,56,55,...,56,56,56,56,56,56,56,56,56,56
20180709,high,False,97,97,97,0,97,97,97,97,97,83,...,97,97,97,97,97,97,97,97,97,97
20180709,high,True,32,32,32,0,32,32,32,32,32,32,...,32,32,32,32,32,32,32,32,32,32
20180709,low,False,91,91,91,0,91,91,91,91,91,73,...,91,91,91,91,91,91,91,91,91,91
20180709,low,True,38,38,38,0,38,38,38,38,38,37,...,38,38,38,38,38,38,38,38,38,38
20180712,high,False,42,42,42,0,42,42,42,42,42,41,...,42,42,42,42,42,42,42,42,42,42
20180712,high,True,14,14,14,0,14,14,14,14,14,3,...,14,14,14,14,14,14,14,14,14,14
