In [None]:
%load_ext autoreload
%autoreload 2

import os
from os.path import join as pjoin
from datetime import datetime
import pingouin as pg
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
import numpy as np
from pingouin import ttest

from bmp_config import path_data, envcode2env
from bmp_behav_proc import *

# Variability (compare with Tan)

## Correlation of error sensitivity (ES) with variance (and other statistics)

### Fixed history length (histlen) across subjects

In [None]:
# Switch between reading the cached table (truthy) and rebuilding it from
# the multi-shift-size table (falsy).
load = 1
if not load:
    # Rebuild path: read the full table, select one parameter combination,
    # then derive the behavioural and windowed-statistic columns.
    fnf = pjoin(path_data,'df_all_multi_tsz__.pkl.zip')
    print(fnf)
    print(datetime.fromtimestamp(os.path.getmtime(fnf)))
    df_all_multi_tsz = pd.read_pickle(fnf)
    df = (
        df_all_multi_tsz
        .query('trial_shift_size == 1 and trial_group_col_calc == "trialwe" '
               ' and retention_factor_s == "0.924"')
        .copy()
        .sort_values(['subject','trials'])
    )
    df,dfall,ES_thr,envv,pert = addBehavCols2(df)
    dfc = df.copy()
    dfcs,dfcs_fixhistlen,dfcs_fixhistlen_untrunc,histlens = addWindowStatCols(
        dfc, ES_thr, varn0s = ['error','error_pscadj','error_pscadj_abs'])
    dfcs_fixhistlen, me_pct_excl = truncLargeStats(dfcs_fixhistlen_untrunc, histlens, 5.)
else:
    # Cached path: show which file is read and when it was last modified.
    fnf_fhl = pjoin(path_data,'dfcs_fixhistlen.pkl')
    print(fnf_fhl)
    print(datetime.fromtimestamp(os.path.getmtime(fnf_fhl)))
    dfcs_fixhistlen = pd.read_pickle(fnf_fhl)

In [None]:
import pingouin as pg
# Store the environment code as int (assigned back onto the same frame).
dfcs_fixhistlen['environment'] = dfcs_fixhistlen.loc[:, 'environment'].astype(int)
# Trial identifiers together with the 'error_pscadj_abs_Tan29' column.
df_ = dfcs_fixhistlen.loc[:, ['environment','subject','trials','error_pscadj_abs_Tan29']]
# Every (subject, trials) pair must occur only once.
assert not df_.duplicated(subset=['subject','trials']).any()

In [None]:
# Column-name suffixes that are combined with a base variable name and a
# history length as f'{varn0}_{suffix}{n}'.
all_suffixes = 'mav,std,invstd,mavsq,mav_d_std,mav_d_var,Tan,invmavsq,invmav,std_d_mav,invTan'.split(',')

# Spot-check one base variable at one history length.
varn0 = 'error_pscadj'
n = 3
missing_cols = [f'{varn0}_{suffix}{n}' for suffix in all_suffixes
                if f'{varn0}_{suffix}{n}' not in dfcs_fixhistlen.columns]
# Fail once, naming every absent column, instead of stopping at the first one
# with a message-less assertion.
assert not missing_cols, f'columns missing from dfcs_fixhistlen: {missing_cols}'

print( all_suffixes )

In [None]:
import gc
# Accumulator for run_model outputs: the model-fitting cell appends to it in
# its serial branch and rebinds it in its parallel branch.
prl = []

In [None]:
# TODO: make prev error abs

In [None]:
# Display all column labels of the table.
dfcs_fixhistlen.columns

In [None]:
# List the columns whose name ends with 'error_mav4'.
[col for col in dfcs_fixhistlen.columns if col.endswith('error_mav4')]

In [None]:
# run long calc
from joblib import Parallel, delayed
from itertools import product
import statsmodels.api as sm
import statsmodels.formula.api as smf
import warnings
from statsmodels.tools.sm_exceptions import ConvergenceWarning    
from numpy.linalg import LinAlgError
import traceback
import gc
# Run a garbage-collection pass before the long computation starts.
gc.collect()
# Passed as ``n_jobs`` to ``model.fit`` inside run_model.
n_jobs_inside = 1

# Define the function to be executed in parallel
def run_model(args, ret_res = False):
    """Fit mixed-effects models of ``err_sens`` on one windowed statistic.

    Parameters
    ----------
    args : tuple
        ``(dfcs_fixhistlen, cocoln, std_mavsz_, varn0, varn_suffix)``.
        The regressor column is ``f'{varn0}_{varn_suffix}{std_mavsz_}'``.
        ``cocoln`` names a categorical covariate column; the string
        ``'None'`` selects a single model without covariate.
    ret_res : bool
        When true, the raw fit results are also returned under ``'s2res'``.

    Returns
    -------
    dict
        Keys ``'cocoln'``, ``'histlen'``, ``'varn'``, ``'varn0'``,
        ``'varn_suffix'``, ``'excfmt'`` (``None`` or the formatted
        traceback(s) of fits that raised ``LinAlgError``), ``'s2summary'``
        and ``'retention_factor'`` (value of ``retention_factor_s`` in the
        first row used).  ``'s2summary'`` maps each
        ``(formula, re_formula)`` pair to the fit summary, or to ``None``
        when the fit failed or did not converge.

    Notes
    -----
    Only ``LinAlgError`` is caught; any other exception propagates.  The
    function reads the module-level ``n_jobs_inside``.
    """
    dfcs_fixhistlen, cocoln, std_mavsz_, varn0, varn_suffix = args
    varn = f'{varn0}_{varn_suffix}{std_mavsz_}'
    # Keep only rows where the regressor and the response are present and the
    # regressor is finite.
    df_ = dfcs_fixhistlen.dropna(subset=[varn, 'err_sens'])
    df_ = df_[~np.isinf(df_[varn])]

    def _converged2(res):
        # Stricter flag: the optimizer converged AND no estimate or p-value is NaN.
        return res.converged and \
            ( not (res.params.isna().any() or res.pvalues.isna().any()) )

    excfmts = []
    if cocoln == 'None':
        s,s2 = f"err_sens ~ {varn}","1"
        model = smf.mixedlm(s, df_,
                    groups=df_["subject"])
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore',category=ConvergenceWarning)
            result = model.fit()
        # The summary loop below reads these attributes for every converged
        # result; without them this branch raised AttributeError.
        result.converged2 = _converged2(result)
        result.wmess = []  # warnings are not recorded in this branch
        results = {(s,s2): result}
    else:
        # Fixed-effects formulas, from the simplest to the fullest.
        fe_additive = f"err_sens ~ C({cocoln}) + {varn}"
        fe_interact = f"{fe_additive} + C({cocoln}) * {varn}"
        fe_full = f"{fe_interact} + {varn} * prev_error_pscadj_abs"
        # Random-effects formulas: per-covariate-level effects or intercept only.
        re_by_cocol = f"~C({cocoln})"
        re_intercept = "1"
        flas = [(fe, re) for fe in (fe_full, fe_interact, fe_additive)
                for re in (re_by_cocol, re_intercept)]

        results = {}
        for s,s2 in flas:
            try:
                model = smf.mixedlm(s, df_.copy(),
                            groups=df_["subject"], re_formula=s2)
                with warnings.catch_warnings(record=True) as w:
                    # n_jobs argument does not really work :(
                    result = model.fit(n_jobs =n_jobs_inside)
                    result.converged2 = _converged2(result)
                    result.wmess = [warning.message for warning in w]
            except LinAlgError:
                # Keep every traceback, not only the one of the last failure.
                excfmts.append(traceback.format_exc())
                result = None
            results[(s,s2)] = result

    s2summary = {}
    for stpl,result in results.items():
        if (result is not None) and result.converged:
            summary = result.summary()
            # Add the stricter convergence flag to the first summary table.
            summary.tables[0].loc[5,2] = 'Converged2:'
            summary.tables[0].loc[5,3] = 'Yes' if result.converged2 else 'No'
            summary.wmess = result.wmess
            summary.params = result.params
            summary.pvalues = result.pvalues
        else:
            summary = None
        s2summary[stpl] = summary
    # Progress output: the parameter combination that was just processed.
    print(args[1:])
    r = {'cocoln': cocoln, 'histlen': std_mavsz_,
            'varn': varn, 'varn0':varn0, 'varn_suffix':varn_suffix,
             'excfmt': '\n'.join(excfmts) if excfmts else None,
            's2summary': s2summary, 'retention_factor':df_.iloc[0]['retention_factor_s']}
    if ret_res:
        r['s2res'] = results
    return r

N = 25
# Create args array using product
cocols = ['None', 'env', 'ps2_']
#cocols = [ 'env'] #, 'ps2_']
#std_mavsz_range = range(2, 30)
std_mavsz_range = list(range(3,N)) + list( range(N, 40, 3) )
#std_mavsz_range = range(2, 20)
#std_mavsz_range = range(2, 12)
#std_mavsz_range = range(2, 3)
varn0s = ['error_pscadj', 'error_pscadj_abs']
#varn_suffixes = ['std', 'invstd', 'mavsq', 'mav_d_std', 'mav_d_var', 'Tan']
varn_suffixes = all_suffixes

# shorter
# NOTE: this overrides the cocols list assigned above, so the 'None'
# (no covariate) models are not part of the grid.
cocols = [ 'env', 'ps2_']
#varn0s = ['error'] #, 'error_pscadj', 'error_pscadj_abs']
#varn_suffixes = ['std', 'invstd', 'Tan']
#N_ = 10
#std_mavsz_range = list(range(3,N_)) + list( range(N_, 32, 3) )

# #std_mavsz_range = range(2, 15)
# std_mavsz_range = range(2, 30)
# varn0s = ['error_pscadj_abs']
# #varn0s = ['error_pscadj', 'error_pscadj_abs']
# #cocols = [ 'env', 'ps2_']
# #cocols = [ 'ps2_']
# cocols = [ 'env']

# One entry per (cocoln, history length, varn0, suffix) combination.
args = list(product(cocols, std_mavsz_range, varn0s, varn_suffixes))
print(len(args))
# Number of processes
#n_jobs = 10  # Use all available CPUs even with one job
#n_jobs = 5
n_jobs = 1

# Results are written to disk every `chunk_size` models in the serial branch.
chunk_size = 100

def _save_prl_chunk(chunk, chunk_ind):
    """Save one list of run_model outputs as prl_<chunk_ind>_<timestamp>.npz in path_data."""
    # strftime always yields second resolution; slicing str(datetime.now())
    # cuts into the seconds whenever the microsecond field happens to be 0.
    s_ = datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
    np.savez( pjoin(path_data, f'prl_{chunk_ind}_{s_}'), chunk )

# Start from an empty accumulator so that re-running this cell neither
# appends to nor re-saves results left over from a previous run.
prl = []
ind = 0
if n_jobs > 1:
    # Execute in parallel
    # NOTE: in this branch the results are only kept in memory (prl); nothing
    # is written to disk.
    backend = 'multiprocessing' # 'loky'
    #backend = 'loky'
    prl = Parallel(n_jobs=n_jobs, backend = backend)\
        (delayed(run_model)( (dfcs_fixhistlen,*arg) ) for arg in args)
else:
    for arg in args:
        prl.append(run_model((dfcs_fixhistlen,*arg)))
#     for arg in args[69:69+1]:
#         prl += [run_model((dfcs_fixhistlen,*arg),ret_res=True)]

        if len(prl) >= chunk_size:
            _save_prl_chunk(prl, ind)
            ind += 1
            # Drop the saved chunk to keep memory usage bounded.
            del prl
            gc.collect()
            prl = []
    # Save the remainder, continuing the chunk numbering without a gap and
    # without writing an empty file.
    if prl:
        _save_prl_chunk(prl, ind)


gc.collect()

In [None]:
# Number of (cocoln, history length, varn0, suffix) combinations in the grid.
print(len(args))