### Compare the weight axes from all the value GLMs to see whether they are similar

In [2]:
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import pandas as pd
import utils.behavioral_utils as behavioral_utils
import utils.information_utils as information_utils
import utils.visualization_utils as visualization_utils
import utils.glm_utils as glm_utils
from matplotlib import pyplot as plt
import utils.spike_utils as spike_utils
from constants.glm_constants import *
from constants.behavioral_constants import *
from spike_tools import (
    general as spike_general,
    analysis as spike_analysis,
)

In [3]:
OUTPUT_DIR = "/data/patrick_res/glm_2"
SESSIONS_PATH = "/data/patrick_res/sessions/valid_sessions_rpe.pickle"
sessions = pd.read_pickle(SESSIONS_PATH)

### Load per-session GLM results and combine them into one table

In [4]:
def get_glm_res(session, feedback_type, residual_str, time_bin_scale=20):
    """Load one session's value-GLM results and tag them for pooling across sessions.

    The pickle is read from ``OUTPUT_DIR``. Its file name is built from
    ``session``, ``feedback_type`` and ``residual_str``, plus the module-level
    ``MODE``, ``INTERVAL_SIZE`` and ``MODEL`` constants.

    Args:
        session: Session identifier; must be accepted by ``int()``.
        feedback_type: Feedback variable of the GLM (e.g. "Response", "RPEGroup").
        residual_str: Firing-rate variant (e.g. "residual_fr", "normal_fr").
        time_bin_scale: Factor that converts ``TimeBins`` values to integer bin
            indices. The default of 20 preserves the original behavior.
            Presumably this means 50 ms bins with TimeBins in seconds — TODO confirm.

    Returns:
        The loaded DataFrame with three added columns:
        - ``TimeIdxs``: integer time-bin index.
        - ``session``: the ``session`` argument.
        - ``PseudoUnitID``: ``int(session) * 100 + UnitID``.
    """
    res = pd.read_pickle(os.path.join(
        OUTPUT_DIR,
        f"{session}_glm_{feedback_type}_{residual_str}_{MODE}_{INTERVAL_SIZE}_{MODEL}_values.pickle",
    ))
    scaled = res["TimeBins"] * time_bin_scale
    # astype(int) truncates toward zero, so a product that lands a hair below an
    # integer (e.g. 2.9999999999999996) would drop to the previous bin. Nudging by
    # a tiny epsilon in the value's own sign direction fixes that. Truncation of
    # genuinely fractional values (e.g. 1.5 -> 1) is left unchanged.
    res["TimeIdxs"] = np.trunc(scaled + np.copysign(1e-9, scaled)).astype(int)
    res["session"] = session
    # NOTE(review): IDs are only unique across sessions if every UnitID is in [0, 100) — confirm.
    res["PseudoUnitID"] = int(session) * 100 + res.UnitID
    return res

### Look at norm lengths of weight vectors as a function of time

In [10]:
fbs = ["Response", "RPEGroup"]
residual_strs = ["residual_fr", "normal_fr"]
# glms maps "<feedback_type>_<residual_str>" -> GLM results pooled over all sessions.
glms = {}
for feedback_type in fbs:
    for residual_str in residual_strs:
        # Iterate the session_name column directly instead of using DataFrame.apply(axis=1).
        # Row-wise apply packs each row into a Series, which can upcast session_name
        # (e.g. int -> float) when other columns are numeric; that would break the file name.
        glm_res = pd.concat([
            get_glm_res(session_name, feedback_type, residual_str)
            for session_name in sessions.session_name
        ])
        glms[f"{feedback_type}_{residual_str}"] = glm_res

In [14]:
# Name the condition explicitly. This is the last one loaded above; don't rely on the leaked loop variable.
glms["RPEGroup_normal_fr"].columns

Index(['UnitID', 'TimeBins', 'score', 'less neg_coef', 'more pos_coef',
       'more neg_coef', 'less pos_coef', 'None_coef', 'YELLOW_coef',
       'CYAN_coef', 'MAGENTA_coef', 'GREEN_coef', 'CIRCLE_coef', 'STAR_coef',
       'SQUARE_coef', 'TRIANGLE_coef', 'POLKADOT_coef', 'SWIRL_coef',
       'ESCHER_coef', 'RIPPLE_coef', 'YELLOW_less neg_coef',
       'CYAN_less neg_coef', 'MAGENTA_less neg_coef', 'MAGENTA_more pos_coef',
       'CYAN_more pos_coef', 'CYAN_more neg_coef', 'GREEN_less neg_coef',
       'YELLOW_more pos_coef', 'GREEN_more pos_coef', 'YELLOW_less pos_coef',
       'MAGENTA_less pos_coef', 'YELLOW_more neg_coef', 'GREEN_less pos_coef',
       'GREEN_more neg_coef', 'CYAN_less pos_coef', 'MAGENTA_more neg_coef',
       'nan_coef', 'CIRCLE_less neg_coef', 'STAR_less neg_coef',
       'SQUARE_more pos_coef', 'CIRCLE_more neg_coef',
       'TRIANGLE_less neg_coef', 'SQUARE_less pos_coef',
       'SQUARE_more neg_coef', 'STAR_more pos_coef', 'STAR_less pos_coef',
       'TRI

### Correlations of interest
- Residual vs. normal firing rate (same feedback variable): Response residual vs. Response normal; RPEGroup residual vs. RPEGroup normal
- Response vs. RPEGroup (same firing-rate treatment): Response residual vs. RPEGroup residual; Response normal vs. RPEGroup normal

In [9]:
# Pairs of GLM conditions (keys of `glms`) whose weight axes are correlated.
pairs = [
    # residual vs. normal firing rate, same feedback variable
    ("Response_residual_fr", "Response_normal_fr"),
    ("RPEGroup_residual_fr", "RPEGroup_normal_fr"),
    # Response vs. RPEGroup, same firing-rate treatment
    ("Response_residual_fr", "RPEGroup_residual_fr"),
    ("Response_normal_fr", "RPEGroup_normal_fr"),
]

# Each per-pair result is meant to be an array of shape (num_time_bins, num_feats).
# NOTE(review): 56 time bins is hard-coded — confirm it matches the TimeIdxs range of the GLM results.
# 12 matches the colour/shape/pattern features whose *_coef columns are listed above.
num_time_bins, num_feats = 56, 12
pair_res = dict()
for (cond_a, cond_b) in pairs:
    corrs = np.empty((num_time_bins, num_feats))
    for time_bin in range(num_time_bins):
        for feat_idx in range(num_feats):
            feat = FEATURES[feat_idx]
            glm_res_a = glms[cond_a]
            glm_res_b = glms[cond_b]
            vec_a = glm_res_a[glm_res_a.TimeIdxs == time_bin][]