Fit simple linear models that predict female behavior from surrogate neural activity, where the surrogate activity is generated using perturbed versions of the Baker et al. population fits.

In [None]:
%matplotlib inline
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy import stats
from sklearn import linear_model
import sys

from disp import set_plot
from my_torch import skl_fit_lin, torch_fit_lin

# Short alias for np.concatenate.
cc = np.concatenate

# Number of items (presumably trials) shuffled into train/test splits, and
# how many of them go into each training set (80%, rounded to an int).
NTR = 276
NTRAIN = round(NTR * 0.8)
# Number of independent random train/test splits to fit.
NSPLIT = 30

# Linear-fit routine and extra keyword arguments forwarded to it.
FIT_LIN = skl_fit_lin
FIT_KWARGS = dict()

# Behavioral column to predict (alternative target: 'FLS').
TARG = 'FFV'

# tau_r values swept below; each selects a 'taufixed_<tau_r>' dataset.
TAU_RS = [0.1, 0.5, 1, 2, 3, 5, 7, 10, 15, 20, 30]

# Loop-invariant inputs: the Baker et al. fit parameters and the response
# columns derived from them do not depend on tau_r, so load them once.
DF_BAKER_FIT_PARAM = pd.read_csv('data/simple/neur/baker_dyn_fit_param.csv')
R_COLS_USE = [f'R_{ir}' for ir in DF_BAKER_FIT_PARAM['EXPT_ID']]

# Which example from the first saved split to plot (train and held-out).
TRAIN_IDX_PLOT = 0
TEST_IDX_PLOT = 0

for tau_r in TAU_RS:
    KEY = f'taufixed_{tau_r}'
    print(KEY)

    PFX_BEHAV = f'data/simple/mlv_c/perturbed_ppln/c_baker_dyn_{KEY}/mlv_c_baker_dyn_{KEY}'
    SAVE_FILE = f'data/simple/mlv_c/perturbed_ppln/c_baker_dyn_{KEY}_{TARG.lower()}_{NTR}_tr.npy'

    # Fit one regression model per random train/test split.
    # NOTE(review): np.random is not seeded, so splits (and results) differ between runs.
    rgrs = []
    for csplit in range(NSPLIT):
        print(f'Split {csplit}')
        rnd_tr_idxs = np.random.permutation(NTR)
        idxs_train = rnd_tr_idxs[:NTRAIN]
        idxs_test = rnd_tr_idxs[NTRAIN:]
        rgr = FIT_LIN(PFX_BEHAV, R_COLS_USE, TARG, idxs_train, idxs_test, **FIT_KWARGS)
        rgrs.append(rgr)

    # Save R2 and weights for every split, example predictions from the
    # first two splits, and the fit configuration.
    save_data = {
        'r2_train': np.array([rgr.r2_train for rgr in rgrs]),
        'r2_test': np.array([rgr.r2_test for rgr in rgrs]),
        'w': np.array([rgr.w for rgr in rgrs]),
        'ys_train': [rgr.ys_train for rgr in rgrs[:2]],
        'y_hats_train': [rgr.y_hats_train for rgr in rgrs[:2]],
        'ys_test': [rgr.ys_test for rgr in rgrs[:2]],
        'y_hats_test': [rgr.y_hats_test for rgr in rgrs[:2]],
        'fit_fn': FIT_LIN.__name__,
        'fit_kwargs': FIT_KWARGS,
        'ntr': NTR,
        'nsplit': NSPLIT,
        'nr': len(R_COLS_USE),
    }

    # Wrapped in a 1-element object array, hence allow_pickle=True on reload.
    np.save(SAVE_FILE, np.array([save_data]))

    # Plot from the reloaded file so the figure reflects what was saved.
    data = np.load(SAVE_FILE, allow_pickle=True)[0]
    gs = gridspec.GridSpec(3, 3)
    fig = plt.figure(figsize=(12, 10), tight_layout=True)
    axs = [
        fig.add_subplot(gs[0, 0]),   # R2 histogram
        fig.add_subplot(gs[0, 1:]),  # mean weight per response
        fig.add_subplot(gs[1, :]),   # example training prediction
        fig.add_subplot(gs[2, :]),   # example held-out prediction
    ]

    axs[0].hist(np.transpose([data['r2_train'], data['r2_test']]), bins=30)
    axs[0].legend(['Train', 'Test'])
    set_plot(axs[0], x_label='R2', y_label='# splits', font_size=14)

    axs[1].bar(np.arange(data['w'].shape[1]), np.mean(data['w'], axis=0))
    set_plot(axs[1], x_label='Response ID', y_label='Weight', font_size=14)

    for ax, ys_key, y_hats_key, idx_plot, title in [
        (axs[2], 'ys_train', 'y_hats_train', TRAIN_IDX_PLOT, 'Training data'),
        (axs[3], 'ys_test', 'y_hats_test', TEST_IDX_PLOT, 'Held-out data'),
    ]:
        # With only y given, Axes.plot uses 0..len(y)-1 as x (the timestep index).
        ax.plot(data[ys_key][0][idx_plot], c='k', lw=2)
        ax.plot(data[y_hats_key][0][idx_plot], c='r', lw=2)
        ax.legend(['True', 'Predicted'])
        # Label with the actual target so e.g. 'FLS' runs are not mislabelled 'FFV'.
        set_plot(ax, y_lim=(-.5, 1), x_label='Timestep', y_label=TARG, title=title, font_size=14)