In [None]:
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
import statsmodels.api as sm
import pingouin as pg
from os.path import join as pjoin
from bmp_config import path_data, ps_2nice
from bmp_behav_proc import *
from datetime import datetime

# Load the precomputed behavioural table. It is a pickle, so only load trusted files.
fnf = pjoin(path_data, 'df_all_multi_tsz__.pkl.zip')
print(fnf)
# Print the file's modification time so the output records which version was loaded.
print(str(datetime.fromtimestamp(os.path.getmtime(fnf))))
df_all_multi_tsz = pd.read_pickle(fnf)
# Keep a single analysis configuration: shift size 1, 'trialwe' grouping, retention 0.924.
df = (
    df_all_multi_tsz
    .query('trial_shift_size == 1 and trial_group_col_calc == "trialwe" '
           ' and retention_factor_s == "0.924"')
    .copy()
    .sort_values(['subject', 'trials'])
)
df, dfall, ES_thr, envv, pert = addBehavCols2(df)

# keep only rows with trialwpertstage_wb > 0
dfall_not0trial = dfall.query('trialwpertstage_wb > 0').copy()

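# Fig 3: per-stage correlation of `err_sens` with trial number (raw, and with `prev_error_abs` regressed out)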

In [None]:
# compute covariance and normalize it
def f(df_):
    r = np.cov( df_['err_sens'].values, df_['prev_error_abs'].values  )
    return r[0,1] #/ r[1,1]
    
# One covariance value per (subject, env), stored in column 'covabsprev'.
covs_prevabserr_per_subj = (
    dfall_not0trial
    .groupby(['subject', 'env'], observed=True)
    .apply(f, include_groups=False)
    .to_frame()
    .reset_index()
    .rename(columns={0: 'covabsprev'})
)

display(covs_prevabserr_per_subj.iloc[:6])
dfall_aug = dfall_not0trial.merge(covs_prevabserr_per_subj, on=['subject', 'env'])

# err_sens minus prev_error_abs scaled by the group's covariance
dfall_aug = dfall_aug.assign(
    err_sens_prevabserrcorr=dfall_aug['err_sens']
    - dfall_aug['prev_error_abs'] * dfall_aug['covabsprev']
)

In [None]:
# Move the current index into a column and order rows by subject and trial.
dfall_aug = dfall_aug.reset_index().sort_values(['subject_ind', 'trial_index'])
# If an 'index' column already exists (e.g. on a re-run), reset_index names the new
# column 'level_0'; discard it so repeated runs give the same columns.
dfall_aug = dfall_aug.drop(columns=['level_0'], errors='ignore')
use_const = True  # intercept for the per-group OLS fits below; must stay True
# compute residual from OLS
def f(df_):
    # Step 1: Fit a regression model for ES on preverrabs
    preverr = df_['prev_error_abs']  # Adds a constant term to the predictor
    if use_const:
        preverr = sm.add_constant(preverr)  # Adds a constant term to the predictor
    model1 = sm.OLS(df_['err_sens'], preverr).fit()

    trial = df_['trialwpertstage_wb']
    if use_const:
        trial = sm.add_constant(trial)  # Adds a constant term to the predictor
    model2 = sm.OLS(df_['err_sens'], trial).fit()

    #X3 = sm.add_constant(df_['trial_index'])  # Adds a constant term to the predictor
    model3 = sm.OLS(df_['trialwpertstage_wb'], preverr).fit()

    dftmp = pd.DataFrame( {'trial_index':df_['trial_index'].values, 'resid1':model1.resid, 'resid2':model2.resid, 
                           'resid3':model3.resid } )
    return dftmp
    #return pd.DataFrame( dict(zip(['trial_index','resid'], [df_['trial_index'].values ,model.resid] ) ) )

# Grouping column for the per-(subject, stage) regressions ('pert_stage' is an alternative).
groupcol = 'ps2_'
dfr = dfall_aug.reset_index().groupby(['subject', groupcol], observed=True).\
    apply(f, include_groups=False).reset_index()

# Residuals are written back by index alignment on (subject, trial_index).
# A duplicated key would make the alignment fail with an opaque reindex error,
# so check it explicitly.
dfr_by_trial = dfr.set_index(['subject', 'trial_index'])
assert dfr_by_trial.index.is_unique, \
    'duplicate (subject, trial_index) pairs in the residual table'

dfall_aug = dfall_aug.set_index(['subject', 'trial_index'])
dfall_aug['err_sens_prev_error_abs_resid'] = dfr_by_trial['resid1']
dfall_aug['err_sens_trial_resid']          = dfr_by_trial['resid2']
dfall_aug['trial_prev_error_abs_resid']    = dfr_by_trial['resid3']

dfall_aug = dfall_aug.reset_index()

In [None]:
nb = 60  # number of equal-width bins for the trial residual
dfall_aug['trial_prev_error_abs_resid_bin'] = pd.cut(
    dfall_aug['trial_prev_error_abs_resid'], bins=nb)
# bin centres are used as the x values of the line plot
dfall_aug['trial_prev_error_abs_resid_binmid'] = (
    dfall_aug['trial_prev_error_abs_resid_bin'].apply(lambda iv: iv.mid))

from figure.plots import relplot_multi
fg, _ = relplot_multi(
    data=dfall_aug,
    ys=['err_sens_prev_error_abs_resid'],
    x='trial_prev_error_abs_resid_binmid',
    col='ps2_',
    kind='line',
    facet_kws={'sharex': False},
    errorbar='sd',
)
fg.refline(y=0)

In [None]:
from figure.plots import relplot_multi
# ES residual vs. binned trial residual, one column per ps2_ stage, horizontal line at 0.
fg, _ = relplot_multi(
    data=dfall_aug, x='trial_prev_error_abs_resid_binmid',
    ys=['err_sens_prev_error_abs_resid'], col='ps2_', kind='line',
    errorbar='sd', facet_kws={'sharex': False})
fg.refline(y=0)

In [None]:
# calculate correlations (just to print numbers in the corner)
# Per-stage ('ps2_') correlations and one-sample tests of per-subject r against zero.
# corrs_gt0 / pcorrs_gt0 are passed to make_fig3_v2 later; corrs_sig, pcorrs_sig
# and cols are displayed again in the Stats section.
print('Comparing correlation values with zero')
for method in ['spearman']:
    # NOTE(review): presumably a partial correlation controlling for prev_error_abs
    # via corrMean's `covar` argument -- TODO confirm in bmp_behav_proc.corrMean
    pcorrs_per_subj_me_0,_ = corrMean(dfall_aug, covar = 'prev_error_abs', 
                               stagecol = 'ps2_', method=method)
    
    # corrMean with its default column pair (not visible here -- TODO confirm which)
    corrs_per_subj_me_,corrs_per_subj  = corrMean(dfall_aug, stagecol = 'ps2_', method=method)
    # correlation between the two OLS residual columns computed earlier
    pcorrs_per_subj_me_2,pcorrs_per_subj = corrMean(dfall_aug, coln='err_sens_prev_error_abs_resid', 
            coltocorr='trial_prev_error_abs_resid', stagecol = 'ps2_', method=method)
    
    # stats for Fig 3 caption
    #from behav_proc import compare0
    def f(df):
        # Test column 'r' of one (method, ps2_) group with compare0.
        # Empty groups (presumably unobserved category combinations) give None.
        if len(df) == 0:
            return None
        ttrs = compare0(df, 'r', cols_addstat=['r'])
        return ttrs
    
    #print(getAddInfo())
    
    print('ps_-sep corr')
    ttrs = corrs_per_subj.\
        groupby(['method','ps2_']).apply(f, include_groups=False)
    # Holm multiple-comparison correction over all tested rows
    ttrs = multi_comp_corr(ttrs.reset_index(), 'holm')
    corrs_gt0 = ttrs.query('alt == "greater"').set_index('ps2_') # for later
    corrs_sig = ttrs.query('pval <= 0.05').reset_index()
    #cols = ['method','ps2_','dof','T','pval','alt','ttstr','r_mean', 'r_std']
    cols = ['method','ps2_','ttstr','r_mean','pval', 'r_std','mc_corr_method']
    with pd.option_context('display.precision', 2):
        display(corrs_sig[cols])
    
    # same procedure for the residual-based (partial) correlations
    print('ps_-sep partial corr')
    ttrs = pcorrs_per_subj.\
        groupby(['method','ps2_']).apply(f, include_groups=False)
    ttrs = multi_comp_corr(ttrs.reset_index(), 'holm')
    pcorrs_gt0 = ttrs.query('alt == "greater"').set_index('ps2_') # for later
    pcorrs_sig = ttrs.query('pval <= 0.05').reset_index()
    with pd.option_context('display.precision', 2):
        display(pcorrs_sig[cols])

In [None]:
# residual-correlation summary ('mestd*0' level) for the 'pre' stage
display(
    pcorrs_per_subj_me_2
    .loc[('mestd*0', slice(None))]
    .loc['pre']
)

In [None]:
# Compare the covariate-based estimate (_me_0) with the residual-based one (_me_2):
# show both tables, then the largest absolute difference in r and in p-value.
display(pcorrs_per_subj_me_0, pcorrs_per_subj_me_2)
print((pcorrs_per_subj_me_0['r'] - pcorrs_per_subj_me_2['r'])
      .abs()
      .max())
print((pcorrs_per_subj_me_0['pval'] - pcorrs_per_subj_me_2['pval'])
      .abs()
      .max())

In [None]:
# Per-stage summary stats (r, mean/std of x and y) from the 'mestd*0' level.
pswb2r_ = corrs_per_subj_me_.loc[('mestd*0',slice(None))]
pswb2r = pswb2r_.to_dict()
pswb2pr_ = pcorrs_per_subj_me_2.loc[('mestd*0',slice(None))] # corr of two residuals
pswb2pr = pswb2pr_.to_dict()

print(pswb2r_[['r','pval','method']]) #mean pval
print(pswb2pr_[['r','pval','method']]) #mean pval


def _regline_pred(row, stats, xcol, skip_stages=()):
    """Regression-line prediction for one row from per-stage correlation stats.

    y_hat = mean_y + r * (x - mean_x) / std_x * std_y

    Parameters
    ----------
    row : pandas.Series
        One row of dfall_aug; ``row['ps2_']`` selects the stage.
    stats : dict
        ``{stat_name: {stage: value}}`` as produced by ``DataFrame.to_dict()``.
        It needs the keys 'r', 'std_x', 'mean_x', 'std_y' and 'mean_y'.
    xcol : str
        Column of ``row`` holding x.
    skip_stages : container
        Stages that get ``None`` instead of a prediction.

    Raises KeyError if the stage is neither skipped nor present in ``stats``.
    """
    ps = row['ps2_']
    if ps in skip_stages:
        return None
    r = stats['r'][ps]
    std_x = stats['std_x'][ps]
    mean_x = stats['mean_x'][ps]
    std_y = stats['std_y'][ps]
    mean_y = stats['mean_y'][ps]
    return mean_y + r * (row[xcol] - mean_x) / std_x * std_y


# direct correlation: x is the trial index within the perturbation stage
dfall_aug['pred'] = dfall_aug.apply(
    _regline_pred, axis=1, args=(pswb2r, 'trialwpertstage_wb'))
# residual (partial) correlation: x is the trial residual; stage -1 is skipped
dfall_aug['ppred'] = dfall_aug.apply(
    _regline_pred, axis=1, args=(pswb2pr, 'trial_prev_error_abs_resid', (-1,)))

In [None]:
# Layout and colouring for Fig 3: one column per stage, hue/palette taken from subenv2color.
col_order = ['pre', 'pert', 'washout', 'rnd']
hues = None                   # no custom hue grouping
coord_let = (0, 1)            # presumably the anchor of the panel letters -- TODO confirm
coord_let_shift = (-50, 20)   # presumably an offset for the panel letters -- TODO confirm

from figure import subenv2color
hue_order = list(subenv2color.keys())
palette = list(subenv2color.values())
# TODO: start of pert
# TODO: sd instead of se

from figure.plots import make_fig3_v2
fnfbs = make_fig3_v2(
    dfall_aug, palette, hue_order, col_order, ps_2nice,
    hues, pswb2r, pswb2pr, corrs_gt0,
    pcorrs_gt0, coord_let, coord_let_shift,
    show_plots=1, show_reg=1, hue='ps2_', show_ttest_alt_type=False,
    fontsize_r=15, fontsize_panel_let=24, fsz_lab=18, fontsize_title=22,
)

In [None]:
%matplotlib inline
# merge svg files
from bmp_config import path_fig
from figure.imgfilemanip import *
svg_files = [ fnfb + '.svg' for fnfb in fnfbs ]
restree = stack_svg(svg_files,'vertical')

fnfout = pjoin(path_fig, 'behav', f'Fig3_stacked_dynESps2_2.svg')
restree.write(fnfout)
print(f"SVG files have been combined and saved as {fnfout}")

from IPython.display import SVG, display, Image
# Display the stacked SVG file
print('svg pic:')
display(SVG(filename=fnfout))

svg2png(fnfout)

# Stats: correlation tests per stage and paired comparisons between stages (`ps2_`)

In [None]:
# Columns shown for the one-sample tests of per-subject correlations against zero.
# The list is defined locally (the same list the correlation cell uses). The global
# `cols` is reassigned by the between-stage cell, so relying on it breaks when
# cells are re-run out of order.
cols_corr_vs0 = ['method','ps2_','ttstr','r_mean','pval', 'r_std','mc_corr_method']
with pd.option_context('display.precision', 2):
    print('ps_-sep corr')
    display(corrs_sig[cols_corr_vs0])
    print('ps_-sep partial corr')
    display(pcorrs_sig[cols_corr_vs0])

In [None]:
mccm = 'holm'  # multiple-comparison correction method passed to comparePairs
cols = ['ttstr','T','pval','dof','starcode','mc_corr_method']


def _compare_stage_corrs(corrs_flat, corr_type):
    """Paired between-stage comparisons of per-subject correlations.

    Runs ``comparePairs`` on column ``r`` across ``ps2_`` stages (paired,
    alt='greater', correction ``mccm``). It then:

    * tags the full result table with ``corr_type``;
    * checks that ``corrs_flat`` contains a single correlation method;
    * displays the first returned table.

    Returns the ``(ttrssig, ttrs)`` pair exactly as returned by comparePairs.
    """
    sig, allres = comparePairs(corrs_flat, 'r', 'ps2_', pooled=0, alt=['greater'],
                               updiag=False, paired=True, multi_comp_corr_method=mccm)
    allres['corr_type'] = corr_type
    # Validate the table that was actually compared. The original partial-corr
    # section checked `cps` here by mistake.
    assert corrs_flat.method.nunique() == 1
    print(corrs_flat.method.unique())
    with pd.option_context('display.precision', 2):
        display(sig[cols])
    return sig, allres


cps = corrs_per_subj.reset_index()
ttrssig, ttrs = _compare_stage_corrs(cps, 'direct_corr')
ttrs_corrs = ttrs

print('partial corr:')
pcps = pcorrs_per_subj.reset_index()
ttrssig, ttrs = _compare_stage_corrs(pcps, 'partial_corr')
ttrs_pcorrs = ttrs

In [None]:
# Significant between-stage comparisons for both correlation types.
# The column list is local so this cell does not depend on whichever cell last set `cols`.
cols_pairs = ['ttstr','T','pval','dof','starcode','mc_corr_method']
ttrs_corrs_and_pcorrs = pd.concat([ttrs_corrs, ttrs_pcorrs])
with pd.option_context('display.precision', 2):
    display(ttrs_corrs_and_pcorrs.query('pval <= 0.05')
            .sort_values(['ttstr', 'corr_type'])[['corr_type'] + cols_pairs])

In [None]:
# 'washout > pert' comparison for direct and partial correlations.
# The column list is local so this cell does not depend on whichever cell last set `cols`.
cols_pairs = ['ttstr','T','pval','dof','starcode','mc_corr_method']
with pd.option_context('display.precision', 2):
    display(ttrs_corrs.query('ttstr == "washout > pert"')[cols_pairs])

    display(ttrs_pcorrs.query('ttstr == "washout > pert"')[cols_pairs])

# Some statistics to print (adapted from behav_manuscript_plots_ju.ipynb); the next cell reloads the behavioural data from disk

In [None]:
from datetime import datetime

# Reload the precomputed behavioural table. It is a pickle, so only load trusted files.
fnf = pjoin(path_data, 'df_all_multi_tsz__.pkl.zip')
print(fnf)
# file modification time, to record which version was loaded
print(str(datetime.fromtimestamp(os.path.getmtime(fnf))))
df_all_multi_tsz = pd.read_pickle(fnf)
# same configuration filter as at the top of the notebook
df = (
    df_all_multi_tsz
    .query('trial_shift_size == 1 and trial_group_col_calc == "trialwe" '
           ' and retention_factor_s == "0.924"')
    .copy()
    .sort_values(['subject', 'trials'])
)
df, dfall, ES_thr, envv, pert = addBehavCols2(df)

# keep only rows with trialwpertstage_wb > 0
dfall_not0trial = dfall.query('trialwpertstage_wb > 0').copy()