In [None]:
import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt 
import seaborn as sns 
from matplotlib import lines
from matplotlib.patches import Patch
from matplotlib.pyplot import Line2D
from scipy import stats

In [None]:
# Load the behavioural data and convert the 1-based indices in the CSV to 0-based.
df = pd.read_csv('../data/data2.csv')
df['action'] = df['choice'] - 1
df['block'] = df['block'] - 1
df['trial'] = df['trial'] - 1
df['subject'] = df['subject'] - 1

# Fitted t-RNN parameters (beta, gamma), keyed by the same 'subject' column.
df_p = pd.read_csv('../results/gershman_trnn_param.csv')

# Number of subjects iterated over below (subjects 0..N_SUBJECTS-1 after re-indexing).
N_SUBJECTS = 44

# One array per subject holding that subject's fitted values
# (an empty array if a subject id is absent from df_p).
predict_beta_array = [df_p.loc[df_p.subject == s, 'beta'].values for s in range(N_SUBJECTS)]
predict_gamma_array = [df_p.loc[df_p.subject == s, 'gamma'].values for s in range(N_SUBJECTS)]

# NOTE(review): DataFrame.join aligns on the row index, so this assumes df_p rows are in
# the same order as df — confirm. Rows of df with no matching index in df_p get NaN.
full_df = df.join(df_p.drop(columns='subject'))

# aux: which option (1 or 2) has the larger mu; ties count as option 1.
# best: 1 when the subject's choice was that option, else 0.
full_df['aux'] = (full_df.mu1 < full_df.mu2).astype(int) + 1
full_df['best'] = (full_df['aux'] == full_df['choice']).astype(int)

full_df['diff'] = full_df.mu2 - full_df.mu1
full_df['diff_c'] = pd.cut(full_df['diff'], 10)  # 10 equal-width bins of the mu difference

color_blue = '#1761B0'
color_red = '#D2292D'
color_grey = '#555555'
# Last expression of the cell: displays a 7-step blue→red palette preview.
sns.blend_palette(colors=[color_blue, color_red], n_colors=7)



In [None]:
# Per-subject error values for the Bayesian model, hard-coded (44 entries).
# TODO: provenance not recorded here — load from a results file like the other models.
gershman_bay  = np.array([0.45278034, 0.30322403, 0.2819705 , 0.30908147, 0.26939126,
       0.30010185, 0.31399581, 0.40598829, 0.28786791, 0.24771328,
       0.46413853, 0.17964524, 0.31796209, 0.41342347, 0.2602049 ,
       0.33621256, 0.25868461, 0.35712081, 0.42422287, 0.25254991,
       0.32944093, 0.3492151 , 0.35065013, 0.35662145, 0.26294523,
       0.32671202, 0.47430834, 0.35618661, 0.41753451, 0.34495363,
       0.36580735, 0.3851236 , 0.39693518, 0.48950065, 0.22544874,
       0.32891676, 0.22619385, 0.29087422, 0.22772049, 0.31904306,
       0.38006013, 0.37470428, 0.31770314, 0.30496219])

gershman_ind = pd.read_csv('../results/gershman_individual_theoretical.csv')['bce'].values
gershman_trnn = pd.read_csv('../results/gershman_trnn.csv')['bce'].values
gershman_drnn = pd.read_csv('../results/gershman_drnn.csv')['bce'].values

df_gershman = pd.DataFrame({'gershman_theoretical': gershman_ind, 'gershman_bay': gershman_bay,
                            'gershman_trnn': gershman_trnn, 'gershman_drnn': gershman_drnn})

# Long format: one row per (model, subject). values.T.flatten() walks df_gershman column by
# column, so model_names must list the models in df_gershman's column order.
model_names = ['Theoretical', 'Bayesian', 't-RNN', 'd-RNN']
n_subj = len(df_gershman)
df_gershman_t = pd.DataFrame({'bce': df_gershman.values.T.flatten(),
                              'Model': np.repeat(model_names, n_subj),
                              'data': np.repeat('Gershman', n_subj * len(model_names))})

# One colour per model, in model_names order.
pla = ['tab:green', '#ffc100', '#1761B0',
       sns.blend_palette(['tab:green', '#ffc100', '#D2292D', 'grey', '#1761B0'])[2]]

fig, ax0 = plt.subplots(1, 1, figsize=(3.8, 4))

df_all = df_gershman_t

# Explicit hue_order keeps bar order tied to the hard-coded x tick labels below.
sns.barplot(ax=ax0, data=df_all,
            x='data', y='bce', hue='Model', hue_order=model_names,
            palette=pla,
            edgecolor='k',
            errorbar="se", orient='v')

ax0.legend_.remove()


def add_sig_bracket(ax, x0, x1, y, label, text_x=None, text_y=None, fontsize=18, tick_height=0.0):
    """Draw a significance bracket from x0 to x1 at height y and write `label` above it.

    The bracket is drawn in data coordinates through (x0, y), (x0, y + tick_height),
    (x1, y + tick_height), (x1, y); tick_height=0 gives a flat line. The label is anchored
    at (text_x, text_y), defaulting to the bracket's x-midpoint and y. Neither artist is
    clipped, so both may extend above the axes.
    """
    bracket = lines.Line2D([x0, x0, x1, x1], [y, y + tick_height, y + tick_height, y],
                           lw=1, c='0.2', transform=ax.transData)
    bracket.set_clip_on(False)
    ax.add_line(bracket)
    label_x = (x0 + x1) / 2 if text_x is None else text_x
    label_y = y if text_y is None else text_y
    ax.annotate(label, xy=(label_x, label_y),
                xytext=(0, 1), textcoords='offset points',
                xycoords='data', ha='center', va='bottom',
                fontsize=fontsize, clip_on=False, annotation_clip=False)


# Bars sit at x = -0.3, -0.1, 0.1, 0.3 (one per model, matching the x ticks below).
# NOTE(review): these significance labels are hard-coded; no test is run in this cell —
# confirm they match the reported statistics.
add_sig_bracket(ax0, -0.3, 0.1, 0.362, '*', text_y=0.36)
add_sig_bracket(ax0, -0.1, 0.1, 0.348, '**', text_y=0.344)
add_sig_bracket(ax0, 0.1, 0.3, 0.338, 'n.s.', text_x=np.mean([0.13, 0.3]), fontsize=15)

ax0.set_xlim(-.45, .45)
ax0.set_xlabel('', fontsize=16)
ax0.set_xticks([-.3, -0.1, 0.1, .3])
ax0.set_xticklabels(['Hybrid\nexplor.', 'Bay', 't-RNN', 'd-RNN'], fontsize=14, rotation=0)

ax0.set_ylim(0.2, 0.37)
# Tick positions and labels come from one list so they cannot disagree
# (ticks at 0.25/0.35 labelled 0.2/0.3 would misreport the axis).
yticks = [.2, .3]
ax0.set_yticks(yticks)
ax0.set_yticklabels(yticks, fontsize=20)
ax0.set_ylabel('Error (lower is better)', size=20, labelpad=2)

sns.despine()
fig.savefig('../plots/fig_4A_REV.pdf', bbox_inches='tight')
plt.show()