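### Further PCA exploration, now using the pseudo-trial generation from the belief-sim and projection-update analyses
- Plot examples for pairs of features, split into the Low, High X, and High Y (high belief in the pair's other feature) partitions.
- Form the data matrix of shape $[N, (K \times T \times C)]$ ($N$ = units; $K$ = pseudo-trials per condition, $T$ = time bins, $C$ = conditions).
- Perform PCA and plot the condition-averaged trajectories.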

In [None]:
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import pandas as pd
import utils.behavioral_utils as behavioral_utils
import utils.information_utils as information_utils
import utils.visualization_utils as visualization_utils
import utils.pseudo_classifier_utils as pseudo_classifier_utils
import utils.classifier_utils as classifier_utils

import utils.io_utils as io_utils

import utils.glm_utils as glm_utils
from matplotlib import pyplot as plt
import matplotlib
import utils.spike_utils as spike_utils
import utils.subspace_utils as subspace_utils
from trial_splitters.condition_trial_splitter import ConditionTrialSplitter 
from utils.session_data import SessionData
from constants.behavioral_constants import *
from constants.decoding_constants import *
import seaborn as sns
from scripts.pseudo_decoding.belief_partitions.belief_partition_configs import *
import scripts.pseudo_decoding.belief_partitions.belief_partitions_io as belief_partitions_io

import scipy
import argparse
import copy
import itertools

from tqdm import tqdm
from sklearn.decomposition import PCA


In [2]:
def get_pseudo_frs_for_session(session, args, num_pseudo=100, seed=None):
    """Sample pseudo-trials per belief partition and attach their firing rates.

    For each of the three belief partitions of ``args.feat_pair`` ("Low",
    "High <feat_x>", "High <feat_y>"), ``num_pseudo`` trial numbers are drawn
    with replacement from that partition's trials. Each draw gets a
    ``PseudoTrialNumber`` of ``cond_idx * num_pseudo + i``. Numbers are
    therefore unique across conditions and line up across sessions, which
    lets rows from different sessions be stacked into a pseudo-population.

    Args:
        session: session identifier.
        args: argparse.Namespace config; must provide ``feat_pair``. Behavior
            and firing rates are loaded with a copy whose ``subject`` is
            replaced by this session's own subject.
        num_pseudo: number of pseudo-trials drawn per condition.
        seed: optional seed (or ``np.random.Generator``) for reproducible
            draws. ``None`` (the default) keeps the previous unseeded behavior.

    Returns:
        The rows of ``spike_utils.get_frs_from_args`` for the sampled trials,
        inner-joined on ``TrialNumber``. A trial sampled k times appears k
        times. Added columns: ``TimeIdx``, ``cond``, ``PseudoTrialNumber``
        and ``session``.

    Raises:
        ValueError: if a condition has no trials in this session.
    """
    # Behavior and firing rates use subject-specific args (the session's own
    # subject), since args.subject may be a pooled value such as "both".
    sub_args = copy.deepcopy(args)
    sub_args.subject = behavioral_utils.get_sub_for_session(session)
    beh = behavioral_utils.load_behavior_from_args(session, sub_args)
    beh = behavioral_utils.get_belief_partitions_of_pair(beh, sub_args.feat_pair)

    frs = spike_utils.get_frs_from_args(sub_args, session)
    # Integer time-bin key (Time x 10, rounded), used for grouping downstream.
    frs["TimeIdx"] = (frs["Time"] * 10).round().astype(int)

    rng = np.random.default_rng(seed)
    feat_x, feat_y = args.feat_pair
    conds = ["Low", f"High {feat_x}", f"High {feat_y}"]
    all_pseudo_trials = []
    for cond_idx, cond in enumerate(conds):
        trial_pool = beh.loc[beh.BeliefPartition == cond, "TrialNumber"].unique()
        if len(trial_pool) == 0 and num_pseudo > 0:
            raise ValueError(
                f"session {session} has no trials in belief partition {cond!r}"
            )
        # Sample with replacement so every session contributes exactly
        # num_pseudo pseudo-trials per condition.
        trial_nums = rng.choice(trial_pool, num_pseudo, replace=True)
        # Offset by condition so PseudoTrialNumber never collides across conds.
        cond_offset = cond_idx * num_pseudo
        all_pseudo_trials.append(
            pd.DataFrame({
                "TrialNumber": trial_nums,
                "cond": cond,
                "PseudoTrialNumber": np.arange(num_pseudo) + cond_offset,
            })
        )
    all_pseudo_trials = pd.concat(all_pseudo_trials)
    pseudo_frs = pd.merge(frs, all_pseudo_trials, on="TrialNumber")
    pseudo_frs["session"] = session
    return pseudo_frs

In [None]:
def compute_pcs(data, n_components=10):
    """Fit PCA on z-scored, condition-averaged firing rates.

    Rates are z-scored with ``spike_utils.zscore_frs`` (grouped by
    ``PseudoUnitID``, ``mode="AvgFiringRate"``, read back as
    ``ZAvgFiringRate``). They are then arranged into a matrix with one row per
    ``(cond, Time)`` and one column per ``PseudoUnitID`` in sorted ID order.
    PCA is fit on that matrix and the rows are projected onto the components.

    Args:
        data: long-format frame; assumed to carry ``cond``, ``Time``,
            ``PseudoUnitID`` and ``AvgFiringRate`` columns. TODO: confirm
            against ``spike_utils.zscore_frs``.
        n_components: number of principal components to keep.

    Returns:
        ``(pca, pcs)``: the fitted ``PCA`` and a DataFrame with columns
        ``cond``, ``Time``, ``PC0`` .. ``PC{n_components-1}``.

    Raises:
        ValueError: if a ``(cond, Time, PseudoUnitID)`` entry is duplicated, or
            if the matrix has NaNs (a unit missing for some row, or an
            undefined z-score).
    """
    zscored = spike_utils.zscore_frs(data, group_cols=["PseudoUnitID"], mode="AvgFiringRate")
    # pivot keeps unit columns aligned by PseudoUnitID across every row and
    # raises on duplicate entries. The alternative, per-group sorted arrays
    # stacked together, silently relies on each group having the same units.
    wide = zscored.pivot(index=["cond", "Time"], columns="PseudoUnitID", values="ZAvgFiringRate")
    if wide.isna().to_numpy().any():
        raise ValueError(
            "data matrix contains NaN (missing units for some (cond, Time) or undefined z-scores)"
        )
    data_matrix = wide.to_numpy()
    pca = PCA(n_components=n_components).fit(data_matrix)
    pc_cols = [f"PC{i}" for i in range(pca.n_components_)]
    pcs = pd.DataFrame(pca.transform(data_matrix), columns=pc_cols, index=wide.index).reset_index()
    return pca, pcs

In [3]:
BOTH_PAIRS_PATH = "/data/patrick_res/sessions/both/pairs_at_least_3blocks_10sess.pickle"
# Feature-pair table for the "both" subject set, with a fresh 0..n-1 index.
pairs = pd.read_pickle(BOTH_PAIRS_PATH)
pairs = pairs.reset_index(drop=True)

# Start from the default belief-partition config, then apply overrides.
default_configs = BeliefPartitionConfigs()
args = argparse.Namespace(**default_configs._asdict())
args.subject = "both"
args.mode = "pref"
# Alternative unit-significance filters, currently disabled:
# args.sig_unit_level = "belief_partition_95th"
# args.sig_unit_level = "pref_99th_window_filter_drift"
args.trial_event = "StimOnset"
args.trial_interval = get_trial_interval(args.trial_event)

In [4]:
# First feature pair in the table (author's note: circle, square,
# 415 units, 14 sessions).
pair = pairs.iloc[0]
args.feat_pair = pair["pair"]


In [61]:
# Pseudo-trial firing rates for every session of this pair, stacked into one frame.
all_res = pd.concat([get_pseudo_frs_for_session(sess, args) for sess in pair.sessions])
# Optional restriction to bins with Time < 0 (disabled):
# all_res = all_res[all_res.Time < 0]

In [None]:
# Z-score each pseudo-unit, then build the data matrix with one row per
# (pseudo-trial, cond, time bin) and one column per unit in PseudoUnitID order.
zscored = spike_utils.zscore_frs(all_res, group_cols=["PseudoUnitID"], mode="FiringRate")
by_sample = zscored.groupby(["PseudoTrialNumber", "cond", "TimeIdx"])
flattened = by_sample.apply(
    lambda grp: grp.sort_values(by="PseudoUnitID").ZFiringRate.values
).reset_index(name="ZFiringRate")
np_arr = np.stack(flattened["ZFiringRate"].to_list())
# PCA.fit returns the estimator itself.
pca = PCA(n_components=10).fit(np_arr)

In [72]:
# Average z-scored rates over pseudo-trials for each (cond, time bin, unit),
# then stack one unit-ordered vector per (cond, TimeIdx).
cond_avg = (
    zscored.groupby(["cond", "TimeIdx", "PseudoUnitID"])["ZFiringRate"]
    .mean()
    .reset_index(name="ZFiringRate")
)
cond_avg_flat = (
    cond_avg.groupby(["cond", "TimeIdx"])
    .apply(lambda grp: grp.sort_values(by="PseudoUnitID").ZFiringRate.values)
    .reset_index(name="ZFiringRate")
)
cond_arr = np.stack(cond_avg_flat["ZFiringRate"].to_list())



In [None]:
np.shape(cond_arr)  # (n (cond, TimeIdx) rows, n units)

In [74]:
# Project the condition-averaged rows onto the PCs fit on single pseudo-trials.
transformed = pca.transform(X=cond_arr)

In [None]:
np.shape(transformed)  # (n (cond, TimeIdx) rows, n components)

In [76]:
# One PC column per component actually returned by pca.transform, rather
# than a hard-coded 10, so changing n_components does not break this cell.
n_pcs = transformed.shape[1]
cond_avg_pca_df = pd.DataFrame(data=transformed, columns=[f"PC{i}" for i in range(n_pcs)])
# Both frames have a 0..n-1 RangeIndex, so concat aligns row by row
# (cond/TimeIdx metadata next to their PC scores).
cond_avg_pca_res = pd.concat((cond_avg_flat, cond_avg_pca_df), axis=1)

In [None]:
# Plot consecutive PC pairs (PC0 vs PC1 ... PC8 vs PC9) of the
# condition-averaged trajectories, marking each condition's final time bin.
order = cond_avg_pca_res.cond.unique()
colors = ["tab:red", "tab:blue", "tab:green"]

# Row holding each condition's last time bin. It does not depend on the PC
# pair, so compute it once. idxmax + .loc avoids groupby.apply on the
# grouping column, which is deprecated in pandas 2.2.
last_point = cond_avg_pca_res.loc[cond_avg_pca_res.groupby("cond").TimeIdx.idxmax()]

fig, axs = plt.subplots(3, 3, figsize=(10, 8))
for i, ax in enumerate(axs.flat):
    x_col, y_col = f"PC{i}", f"PC{i+1}"
    # sort=False keeps the stored row order, which is (cond, TimeIdx) order.
    sns.lineplot(cond_avg_pca_res, x=x_col, y=y_col, hue="cond", sort=False, legend=False, ax=ax, hue_order=order, palette=colors)
    # Only the first panel carries a legend.
    sns.scatterplot(last_point, x=x_col, y=y_col, hue="cond", ax=ax, hue_order=order, palette=colors, legend="auto" if i == 0 else False)
fig.tight_layout()

In [None]:
import plotly.express as px

# Interactive 3-D view of the condition-averaged trajectories in PC2-PC4.
fig = px.line_3d(
    data_frame=cond_avg_pca_res,
    x="PC2",
    y="PC3",
    z="PC4",
    color="cond",
)
fig.show()


In [79]:
# Project every single pseudo-trial row onto the PCs.
transformed = pca.transform(np_arr)
# Name one column per returned component instead of assuming exactly 10.
pca_df = pd.DataFrame(data=transformed, columns=[f"PC{i}" for i in range(transformed.shape[1])])
# flattened and pca_df both have a 0..n-1 RangeIndex, so rows align.
pca_res = pd.concat((flattened, pca_df), axis=1)

In [89]:
def comp_mean_std(group, n_pcs=10):
    """Summarize PC scores within one group as mean and standard error.

    Despite the function name, the spread statistic is the standard error of
    the mean (pandas ``Series.sem``, ddof=1), not the standard deviation.

    Args:
        group: DataFrame holding columns ``PC0`` .. ``PC{n_pcs-1}``.
        n_pcs: number of leading PC columns to summarize (default 10).

    Returns:
        Series indexed ``PC0_mean, PC0_se, PC1_mean, PC1_se, ...``.
    """
    stats = {}
    for i in range(n_pcs):
        scores = group[f"PC{i}"]
        stats[f"PC{i}_mean"] = scores.mean()
        stats[f"PC{i}_se"] = scores.sem()
    return pd.Series(stats)


# Per (cond, TimeIdx): mean and s.e.m. of each PC across pseudo-trials.
# Passing only the PC columns keeps groupby.apply from handing the grouping
# columns to comp_mean_std (deprecated in pandas 2.2).
pc_cols = [f"PC{i}" for i in range(pca.n_components_)]
mean_std = pca_res.groupby(["cond", "TimeIdx"])[pc_cols].apply(comp_mean_std).reset_index()

In [None]:
# 3-D trajectories of the pseudo-trial means in PC3-PC5, with s.e.m. error bars.
fig = px.line_3d(
    data_frame=mean_std,
    x="PC3_mean",
    error_x="PC3_se",
    y="PC4_mean",
    error_y="PC4_se",
    z="PC5_mean",
    error_z="PC5_se",
    color="cond",
)
fig.show()