# Predict female behavioral variables from surrogate neural population activity for pre-specified sets of strains

In [None]:
%matplotlib inline
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy import stats
from sklearn import linear_model
import sys

from disp import set_plot
from my_torch import skl_fit_ridge

# shorthand alias for array concatenation
cc = np.concatenate


# --- regression targets ---
# One target name per entry of TWDWS, formed as '<TARG_BHV>_MN_<twdw>'.
# NOTE(review): TWDWS are presumably averaging-window lengths; units are
# not shown in this notebook -- confirm against the behavior files.
TARG_BHV = 'MTN'
TWDWS = [.03, 1, 60]
TARGS = ['_'.join((TARG_BHV, 'MN', str(twdw))) for twdw in TWDWS]

# path prefix handed to skl_fit_ridge for the behavioral data
PFX_BHV = 'data/simple/behav_xtd/behav_xtd'

# --- neural models ---
# number of surrogate neurons; response columns are R_0 ... R_{NNRN-1}
NNRN = 224

# Neural models to fit, in order.
# Currently disabled alternatives: 'lnr' (linear-nonlinear model fit w ridge
# regr), 'lnr_relu', 'lnr_relu_flex', 'linr' (linear neural model fit w ridge
# regr), 'ma_ind_ta', 'ma_sia', 'lin' (linear neural model),
# 'ln' (linear-nonlinear model).
NRL_MDLS = [
    'ma',            # mult-adapt neural model
    'lnma',          # lin-nonlin neural model w MA-derived filters
    'lnma_tweaked',  # lin-nonlin neural model w double-exp filter optimized from MA fit
    'linma',         # linear neural model w MA-derived filters
]

# --- fitting ---
PTRAIN = .8   # fraction of the selected trials used for training
NSPLIT = 30   # number of splits passed to skl_fit_ridge

ALPHA = 10    # alpha passed to skl_fit_ridge

# number of leading splits whose targets/predictions are requested back
NSPLIT_SAVE_Y_HAT = 5

# mask path prefix passed to skl_fit_ridge; a falsy value drops the
# '_masked' tag from the save-file suffix
MASK_PFX = 'data/simple/masks/mask'

# suffix shared by every results file written by this notebook
FSAVE_SFX = f'{TARG_BHV.lower()}_ridge_alpha_{ALPHA}' + ('_masked' if MASK_PFX else '')

In [None]:
# STRAIN SETS TO ANALYZE
# CSV whose 'STRAIN' column is used to select rows for each strain set
FSTRAIN = 'data/simple/strains.csv'

# Nested strain sets: each set is the previous one plus the next strain in
# the ordered list, i.e. the first 1, 2, 3 and 4 strains.
STRAINS_ALL = [
    ['NM91', 'ZH23', 'CarM03', 'ZW109'][:nstrain]
    for nstrain in range(1, 5)
]

In [None]:
# For every strain set and every neural model: fit ridge regressions from the
# surrogate neural responses to the behavioral targets over NSPLIT splits,
# then save fit quality, weights and example predictions to one .npy file.

# Loop-invariant inputs, computed once rather than on every iteration:
# the strain label of each row of the strain table, and the names of the
# neural response columns used as regressors.
STRAIN_COL = pd.read_csv(FSTRAIN)['STRAIN']
r_cols_use = [f'R_{inrn}' for inrn in range(NNRN)]

for STRAINS in STRAINS_ALL:
    # lower-cased, underscore-joined strain names; used in the save filename
    STRAIN_KEY = '_'.join(STRAINS).lower()

    # one boolean mask per strain, OR-ed into a single mask over table rows
    MSTRAINS = [(STRAIN_COL == strain) for strain in STRAINS]
    MSTRAIN = np.any(MSTRAINS, axis=0)
    # row positions of the selected strains (reported as trials below)
    # NOTE(review): a strain name absent from the table contributes no rows
    # and is not reported -- check the printed trial counts.
    ISTRAIN = MSTRAIN.nonzero()[0]

    NTR = len(ISTRAIN)
    NTRAIN = int(round(PTRAIN*NTR))

    print(f'STRAINS: {STRAINS} ({NTR} trials)')

    for nrl_mdl in NRL_MDLS:
        print(f'\nMODEL: {nrl_mdl}')
        pfx_nrl = f'data/simple/mlv/neur_basic/baker_{nrl_mdl}/mlv_baker_{nrl_mdl}'

        fsave = f'data/simple/mlv/neur_basic/by_strain/baker_{nrl_mdl}_{STRAIN_KEY}_{FSAVE_SFX}.npy'

        # fit regression models
        rslts = skl_fit_ridge(
            pfxs=[pfx_nrl, PFX_BHV],
            cols_x=r_cols_use,
            targs=TARGS,
            itr_all=ISTRAIN,
            ntrain=NTRAIN,
            nsplit=NSPLIT,
            mask_pfx=MASK_PFX,
            return_y=np.arange(NSPLIT_SAVE_Y_HAT),
            alpha=ALPHA)

        # Results whose ys_train is truthy; only these contribute true and
        # predicted target values to the saved file.
        # NOTE(review): presumably these are the splits requested through
        # return_y -- confirm against skl_fit_ridge.
        rslts_y = [rslt for rslt in rslts if rslt.ys_train]

        # save r2, weights, and example predictions
        save_data = {
            'r2_train': {targ: np.array([rslt.r2_train[targ] for rslt in rslts]) for targ in TARGS},
            'r2_test': {targ: np.array([rslt.r2_test[targ] for rslt in rslts]) for targ in TARGS},

            'w': {targ: np.array([rslt.w[targ] for rslt in rslts]) for targ in TARGS},
            'bias': {targ: np.array([rslt.bias[targ] for rslt in rslts]) for targ in TARGS},

            'ys_train': {targ: [rslt.ys_train[targ] for rslt in rslts_y] for targ in TARGS},
            'ys_test': {targ: [rslt.ys_test[targ] for rslt in rslts_y] for targ in TARGS},

            'y_hats_train': {targ: [rslt.y_hats_train[targ] for rslt in rslts_y] for targ in TARGS},
            'y_hats_test': {targ: [rslt.y_hats_test[targ] for rslt in rslts_y] for targ in TARGS},

            'targs': TARGS,
            'alpha': ALPHA,

            'ntr': NTR,
            'ntrain': NTRAIN,
            'nsplit': NSPLIT,

            'nr': len(r_cols_use)
        }

        # The dict is wrapped in a one-element object array, so reading it
        # back needs np.load(fsave, allow_pickle=True)[0].
        np.save(fsave, np.array([save_data]))