# Predict female behavioral variables from the top PCs of surrogate neural population activity (control analysis: synthetic linear-readout target `CTL`)

In [1]:
%matplotlib inline
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy import stats
from sklearn import linear_model
from sklearn.decomposition import PCA
import sys
import warnings; warnings.filterwarnings("ignore")

from disp import set_plot
from my_torch import skl_fit_ridge

cc = np.concatenate

STRAINS = ['NM91', 'ZH23']
STRAIN_KEY = '_'.join(STRAINS).lower()  # strain tag used in output filenames

PTRAIN = .8  # fraction of the selected trials used for training in each split
NSPLIT = 30  # number of train/test splits passed to skl_fit_ridge

TARGS = ['CTL']  # regression targets: only the synthetic control target built from W_CTL

# Seed the global NumPy RNG so that the control readout weights W_CTL and the noise that
# is later added to the control target are reproducible under Restart & Run All.
SEED = 0
np.random.seed(SEED)

NNRN = 224  # number of surrogate neurons (response columns R_0..R_223)
W_CTL = np.random.randn(NNRN)/np.sqrt(NNRN)  # random linear readout defining the control target

NRL_MDL = 'ma'  # neural-model tag used in the data paths

ALPHA = 10  # ridge regularization strength passed to skl_fit_ridge

NSPLIT_SAVE_Y_HAT = 3  # splits 0..NSPLIT_SAVE_Y_HAT-1 are asked to return targets/predictions

NPCS = np.array([1, 5, 10, 20, 30, 50, 100, 224])  # numbers of top PCs to evaluate
assert NPCS.max() <= NNRN, 'cannot keep more PCs than there are neurons'

ZSCORE = True  # z-score each neuron before PCA

In [2]:
# Prefix of the per-trial behavior files; passed to skl_fit_ridge alongside the neural prefix.
PFX_BHV = 'data/simple/behav_xtd/behav_xtd'

# Select the trials (row indices of strains.csv) that belong to any of the requested strains.
# The CSV is read once instead of once per strain.
FSTRAIN = 'data/simple/strains.csv'
strain_per_trial = pd.read_csv(FSTRAIN)['STRAIN']
MSTRAINS = [(strain_per_trial == strain) for strain in STRAINS]
MSTRAIN = np.any(MSTRAINS, axis=0)
ISTRAIN = MSTRAIN.nonzero()[0]

NTR = len(ISTRAIN)
assert NTR > 0, f'no trials found for strains {STRAINS} in {FSTRAIN}'
NTRAIN = int(round(PTRAIN*NTR))  # number of training trials per split

In [3]:
# Load the surrogate neural responses of every selected trial and build, per trial, a
# synthetic control target: CTL = (raw responses) @ W_CTL + Gaussian noise.
NOISE_SD_CTL = 5  # std of the additive Gaussian noise on the control target

pfx_pre_pca = f'data/simple/mlv/neur_basic/baker_{NRL_MDL}/mlv_baker_{NRL_MDL}'

# NOTE(review): allow_pickle=True unpickles these files -- only load trusted, locally generated data.
dfs_tr = [np.load(f'{pfx_pre_pca}_tr_{itr}.npy', allow_pickle=True)[0]['df'] for itr in ISTRAIN]

# Response columns R_0..R_{n-1}; n is tied to the readout weights so the two cannot disagree.
r_cols = [f'R_{ir}' for ir in range(len(W_CTL))]

frs = []
y_ctls = []
tr_lens = []
for df_tr in dfs_tr:
    fr_tr = df_tr[r_cols].to_numpy()
    frs.append(fr_tr)
    tr_lens.append(len(df_tr))
    # the control target uses the *raw* (non-z-scored) responses
    y_ctls.append(fr_tr@W_CTL + NOISE_SD_CTL*np.random.randn(len(df_tr)))

frs = cc(frs, axis=0)

if ZSCORE:
    fr_mn = frs.mean(axis=0)
    fr_sd = frs.std(axis=0)
    # constant neurons have sd == 0; dividing by it would yield NaNs that break the PCA
    frs = (frs - fr_mn) / np.where(fr_sd > 0, fr_sd, 1.)

# output prefix for the PCA-reduced trials ('zpca' when z-scored, 'pca' otherwise)
pca_tag = 'zpca' if ZSCORE else 'pca'
pfx_nrl = f'data/simple/mlv/neur_basic/baker_{NRL_MDL}_{pca_tag}_ctl/mlv_baker_{NRL_MDL}_{pca_tag}_ctl'

In [4]:
# Fit PCA once with the largest requested number of components and slice it per npc.
# PCA components are ordered by explained variance, so the first npc columns of one full
# fit are the top-npc PCs. This avoids refitting per npc and keeps the PCs identical
# across npc (a separate small-n_components fit may use an approximate solver).
npc_max = int(np.max(NPCS))
pca = PCA(n_components=npc_max)
frs_pcs = pca.fit_transform(frs)

# Row boundaries of each trial inside the concatenated `frs` (independent of npc).
it_stops = np.cumsum(tr_lens)
it_starts = cc([[0], it_stops[:-1]])

# Results directory and filename tag follow ZSCORE (via pfx_nrl), so raw-PCA results are
# not written into, or labelled as, the z-scored PCA directory.
dir_nrl = pfx_nrl.rsplit('/', 1)[0]
pc_tag = 'zpc' if ZSCORE else 'pc'

for npc in NPCS:
    print(f'{npc} PCs...')
    frs_reduced = frs_pcs[:, :npc]

    for cpc, expl_var in enumerate(pca.explained_variance_ratio_[:npc]):
        print(f'Var expl by PC {cpc} = {expl_var:.4f}')

    cols_use = [f'PC_{cpc}' for cpc in range(npc)]

    # Save reduced trials: non-neural columns of the original trial + PC_* columns + CTL target.
    for it_start, it_stop, itr, df_tr, y_ctl in zip(it_starts, it_stops, ISTRAIN, dfs_tr, y_ctls):
        df_meta = df_tr[[col for col in df_tr.columns if not col.startswith('R_')]]
        # build all PC columns at once (inserting up to 224 columns one by one fragments the frame)
        df_pcs = pd.DataFrame(frs_reduced[it_start:it_stop], columns=cols_use, index=df_tr.index)
        df_tr_pca = pd.concat([df_meta, df_pcs], axis=1)

        # add control behavior
        df_tr_pca['CTL'] = y_ctl.copy()

        np.save(f'{pfx_nrl}_tr_{itr}.npy', np.array([{'df': df_tr_pca}]))

    fsave = f'{dir_nrl}/baker_{NRL_MDL}_{STRAIN_KEY}_ctl_ridge_alpha_{ALPHA}_{npc}_{pc_tag}.npy'

    # fit ridge regressions from the PC columns to TARGS over NSPLIT train/test splits
    rslts = skl_fit_ridge(
        pfxs=[pfx_nrl, PFX_BHV],
        cols_x=cols_use,
        targs=TARGS,
        itr_all=ISTRAIN,
        ntrain=NTRAIN,
        nsplit=NSPLIT,
        return_y=np.arange(NSPLIT_SAVE_Y_HAT),
        alpha=ALPHA)

    # Save r2, weights, and example predictions. The y/y_hat lists keep only splits whose
    # ys_train is truthy (presumably the return_y splits -- confirm in skl_fit_ridge).
    save_data = {
        'r2_train': {targ: np.array([rslt.r2_train[targ] for rslt in rslts]) for targ in TARGS},
        'r2_test': {targ: np.array([rslt.r2_test[targ] for rslt in rslts]) for targ in TARGS},

        'w': {targ: np.array([rslt.w[targ] for rslt in rslts]) for targ in TARGS},

        'ys_train': {targ: [rslt.ys_train[targ] for rslt in rslts if rslt.ys_train] for targ in TARGS},
        'ys_test': {targ: [rslt.ys_test[targ] for rslt in rslts if rslt.ys_train] for targ in TARGS},

        'y_hats_train': {targ: [rslt.y_hats_train[targ] for rslt in rslts if rslt.ys_train] for targ in TARGS},
        'y_hats_test': {targ: [rslt.y_hats_test[targ] for rslt in rslts if rslt.ys_train] for targ in TARGS},

        'targs': TARGS,
        'alpha': ALPHA,

        'ntr': NTR,
        'ntrain': NTRAIN,
        'nsplit': NSPLIT,

        'nr': len(cols_use)
    }

    print('')
    for targ in TARGS:
        r2_test = save_data['r2_test'][targ]
        print(f'<{targ} R2> = {np.mean(r2_test)}')

    print('')

    np.save(fsave, np.array([save_data]))

1 PCs...
Var expl by PC 0 = 0.7275
Loading...

Split 0
Split 1
Split 2
Split 3
Split 4
Split 5
Split 6
Split 7
Split 8
Split 9
Split 10
Split 11
Split 12
Split 13
Split 14
Split 15
Split 16
Split 17
Split 18
Split 19
Split 20
Split 21
Split 22
Split 23
Split 24
Split 25
Split 26
Split 27
Split 28
Split 29
<CTL R2> = 0.04355992512989202

5 PCs...
Var expl by PC 0 = 0.7275
Var expl by PC 1 = 0.1202
Var expl by PC 2 = 0.0825
Var expl by PC 3 = 0.0267
Var expl by PC 4 = 0.0168
Loading...

Split 0
Split 1
Split 2
Split 3
Split 4
Split 5
Split 6
Split 7
Split 8
Split 9
Split 10
Split 11
Split 12
Split 13
Split 14
Split 15
Split 16
Split 17
Split 18
Split 19
Split 20
Split 21
Split 22
Split 23
Split 24
Split 25
Split 26
Split 27
Split 28
Split 29
<CTL R2> = 0.05735818514759329

10 PCs...
Var expl by PC 0 = 0.7275
Var expl by PC 1 = 0.1202
Var expl by PC 2 = 0.0825
Var expl by PC 3 = 0.0267
Var expl by PC 4 = 0.0168
Var expl by PC 5 = 0.0074
Var expl by PC 6 = 0.0046
Var expl by PC 7 = 0.0037
