In [None]:
import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt 
import seaborn as sns 
from matplotlib import lines
from matplotlib.patches import Patch
from matplotlib.pyplot import Line2D
from scipy import stats

In [None]:
# --- Trial-level data ---------------------------------------------------------
# `action` is `choice` shifted down by one (`choice` itself is kept as loaded);
# block / trial / subject are shifted down by one in place.
df = pd.read_csv('../data/data2.csv')
df['action'] = df['choice'] - 1
for shifted_col in ('block', 'trial', 'subject'):
    df[shifted_col] = df[shifted_col] - 1

# --- Fitted parameters ----------------------------------------------------------
df_p = pd.read_csv('../results/gershman_trnn_param.csv')

# One array of beta / gamma values for each subject id 0..43.
predict_beta_array = [df_p.loc[df_p.subject == i, 'beta'].values for i in range(44)]
predict_gamma_array = [df_p.loc[df_p.subject == i, 'gamma'].values for i in range(44)]

# `join` aligns the two frames on their row index, so this presumably relies on
# both files listing rows in the same order -- TODO confirm.
param_columns = df_p.drop(columns='subject')
full_df = df.join(param_columns)

# `aux` is 2 where mu1 < mu2 and 1 otherwise; `best` is 1 where `choice` equals it.
full_df['aux'] = (full_df['mu1'] < full_df['mu2']).astype(int) + 1
full_df['best'] = (full_df['choice'] == full_df['aux']).astype(int)

# Signed gap mu2 - mu1, and the same gap cut into 10 equal-width bins.
full_df['diff'] = full_df['mu2'] - full_df['mu1']
full_df['diff_c'] = pd.cut(full_df['diff'], bins=10)

# Colours used by the figure cells.
color_blue, color_red, color_grey = '#1761B0', '#D2292D', '#555555'

# Last expression of the cell: displays the 7-step blend between the two colours.
sns.blend_palette(colors=[color_blue, color_red], n_colors=7)

In [None]:
# Fig. 4A: bar plot of the 'bce' column of three result files (mean with SE bars).
gershman_ind = pd.read_csv('../results/gershman_individual_theoretical.csv')['bce'].values
gershman_trnn = pd.read_csv('../results/gershman_trnn.csv')['bce'].values
gershman_drnn = pd.read_csv('../results/gershman_drnn.csv')['bce'].values

df_gershman = pd.DataFrame({'gershman_theoretical': gershman_ind,
                            'gershman_trnn': gershman_trnn,
                            'gershman_drnn': gershman_drnn})

# Long format for seaborn: the three columns are stacked one after another, so
# each model label is repeated once per row of the wide frame. The row count is
# read from the data instead of being hard-coded.
n_rows = len(df_gershman)
df_gershman_t = pd.DataFrame({'bce': df_gershman.values.T.flatten(),
                              'Model': np.repeat(['Theoretical', 't-RNN', 'd-RNN'], n_rows),
                              'data': np.repeat('Gershman', n_rows * 3)})

pla = ['tab:green', '#1761B0',
       sns.blend_palette(['tab:green', '#ffc100', '#D2292D', 'grey', '#1761B0'])[2]]

fig, ax0 = plt.subplots(1, 1, figsize=(3.5, 4))

df_all = df_gershman_t

sns.barplot(ax=ax0, data=df_all,
            x='data', y='bce', hue='Model',
            palette=pla,
            edgecolor='k',
            errorbar="se", orient='v')

ax0.legend_.remove()

# Comparison bars. NOTE: the '*' / 'n.s.' markers and their coordinates are
# literals; no test is computed in this cell.
h = 0  # height of the bracket end ticks; 0 draws a flat bar

line = lines.Line2D([-.26, -.26, 0.01, 0.01], [.355, .355 + h, .355 + h, .355],
                    lw=1, c='0.2', transform=ax0.transData)
line.set_clip_on(False)
ax0.add_line(line)
ax0.annotate('*', xy=(np.mean([-.26, 0.02]), .34 + 0.01),
             xytext=(0, 1), textcoords='offset points',
             xycoords='data', ha='center', va='bottom',
             fontsize=18, clip_on=False, annotation_clip=False)

line = lines.Line2D([0.01, 0.01, 0.26, 0.26], [.34, .34 + h, .34 + h, .34],
                    lw=1, c='0.2', transform=ax0.transData)
line.set_clip_on(False)
ax0.add_line(line)
ax0.annotate('n.s.', xy=(np.mean([0.03, 0.26]), .33 + 0.01),
             xytext=(0, 1), textcoords='offset points',
             xycoords='data', ha='center', va='bottom',
             fontsize=15, clip_on=False, annotation_clip=False)

ax0.set_xlim(-.45, .45)
ax0.set_xlabel('', fontsize=20)
ax0.set_xticks([-.27, 0, .27])
ax0.set_xticklabels(['Hybrid\nexplor.', 't-RNN', 'd-RNN'], fontsize=16, rotation=0)

ax0.set_ylim(0.2, 0.37)
# The tick labels must state the tick positions: the ticks sit at .25 / .35 and
# were previously labelled .2 / .3, which shifted the printed scale by 0.05.
y_ticks = [.25, .35]
ax0.set_yticks(y_ticks)
ax0.set_yticklabels(y_ticks, fontsize=20)
ax0.set_ylabel('Error (lower is better)', size=18, labelpad=2)

sns.despine()
# plt.savefig('../plots/fig_4A.pdf',bbox_inches='tight')
plt.show()

In [None]:
# Example of an ungreedy action (fig. 4B): ten consecutive rows of one subject.
# Top panel: action markers annotated with the reward, plus 1 - p_0 per row.
# Bottom panel: beta over the same rows.
cur = full_df[full_df.subject.eq(43) & full_df.block.between(5, 7)].reset_index()

window = cur.iloc[10:20]        # rows 10..19 of the selection above
trial_pos = np.arange(0, 10)
p_not_0 = 1 - window.p_0

fig, (ax0, ax1) = plt.subplots(2, 1, figsize=(6, 3))

# --- top panel ---
sns.scatterplot(ax=ax0, x=trial_pos, y=window.action, color=color_red, marker='o', s=30)
sns.lineplot(ax=ax0, x=trial_pos, y=p_not_0, lw=2)
sns.scatterplot(ax=ax0, x=trial_pos, y=p_not_0, marker='o', s=30)
ax0.axhline(y=0.5, ls='--', color='k')

# --- bottom panel ---
sns.lineplot(ax=ax1, x=trial_pos, y=window.beta, lw=2, color=color_blue)

# Write each row's reward just above its action marker.
for pos, (act, rew) in enumerate(zip(window.action.values, window.reward.values)):
    ax0.text(x=pos - 0.12, y=act + 0.1, s=rew, size=15, color=color_red)

for ax in (ax0, ax1):
    ax.set_xlim(-.2, 9.3)

# Top panel axes: ticks without labels, plus a hand-placed arrow above the panel.
ax0.set_xticks(list(range(10)))
ax0.set_xticklabels([])
ax0.text(x=4.615, y=1.05, s=r'$\downarrow$', size=22)
ax0.set_ylim(-0.1, 1.1)
ax0.set_ylabel(r'$p(a_{R})$', size=22, labelpad=8)
ax0.tick_params(axis='y', which='major', labelsize=18)

# Bottom panel axes.
ax1.set_xticks(list(range(10)))
ax1.tick_params(axis='x', which='major', labelsize=16)
ax1.set_xlabel('Trial', size=18)
ax1.set_ylim(0.7, 3.3)
ax1.set_yticks([1, 3])
ax1.set_ylabel(r'$\beta$', size=22, labelpad=6)
ax1.tick_params(axis='y', which='major', labelsize=18)

sns.despine()
plt.savefig('../plots/fig_4B.pdf', bbox_inches='tight')
plt.show()


In [None]:
# Mean beta around "ungreedy" actions (~317 rows): rows with trial > 4 where the
# arm with the lower mean was chosen (choice 2 while mu1 > mu2, or choice 1
# while mu1 < mu2).
idx = full_df.query('(mu1>mu2 and choice==2 and trial>4) or (mu1<mu2 and choice==1 and trial>4)').index.values

positions = np.arange(-4, 5)  # row offsets relative to the ungreedy row

windows = []
for i in idx:
    # Rows whose index lies within +/- 4 of the ungreedy row.
    # NOTE(review): the window is taken by row index only, so it may include
    # rows from a neighbouring block -- confirm this is intended.
    cur = full_df[(full_df.index > i - 5) & (full_df.index < i + 5)].copy()
    # Offset of each row from the ungreedy row, derived from the index. A fixed
    # np.arange(-4, 5) raised a length mismatch for windows clipped by the end
    # of the frame. Assumes full_df has a consecutive integer index (the default
    # from read_csv) -- TODO confirm.
    cur['trial'] = cur.index.values - i
    windows.append(cur)

# Concatenate once and reuse the grouping for both statistics; reindex so the
# result always has one entry per plotted position.
beta_by_position = pd.concat(windows).groupby('trial')['beta']
mean_beta = beta_by_position.mean().reindex(positions).values
sem_beta = beta_by_position.sem().reindex(positions).values

fig, ax1 = plt.subplots(1, 1, figsize=(5, 3))

ax1.fill_between(positions, mean_beta - sem_beta, mean_beta + sem_beta, alpha=.2)
sns.scatterplot(ax=ax1, x=positions, y=mean_beta, s=30, color=color_blue, edgecolor='k')
sns.lineplot(ax=ax1, x=positions, y=mean_beta, color=color_blue)
ax1.axvline(x=0, ls='--', color='k')  # position of the ungreedy action

ax1.set_xticks(np.arange(-4, 5, 2))
ax1.tick_params(axis='x', which='major', labelsize=16)
ax1.set_xlabel('Trial position from ungreedy action', size=16)

ax1.set_yticks([1.7, 2.1])
ax1.set_ylim(1.6, 2.2)
ax1.tick_params(axis='y', which='major', labelsize=16)
ax1.set_ylabel(r'$\beta$', size=25, labelpad=-20)
fig.tight_layout()
sns.despine()
# plt.savefig('../plots/fig_4C.pdf') # ,bbox_inches='tight'
plt.show()


In [None]:
# Easy vs. hard comparison (fig. 4C-E): accuracy bars (ax0) and per-trial mean
# beta (ax1) / gamma (ax2). "Easy" rows have abs(mu1-mu2) >= 19 and "hard" rows
# have abs(mu1-mu2) <= 1, matching the legend labels below.
# NOTE: this cell rebinds the notebook-level `color_blue` and `color_grey`, so
# cells executed after it use these values.
color_blue = sns.blend_palette(['#1761B0',sns.color_palette("tab10" )[4]])[3]
color_grey = 'grey'


def _plot_mean_sem_by_trial(ax, trials, column, color):
    """Draw the per-trial mean of `column` on `ax` with a +/- SEM band.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    trials : DataFrame with a 'trial' column and the requested `column`.
    column : name of the column to average within each trial.
    color : colour used for the line, the markers and the band.

    Returns
    -------
    (mean, sem) : the two per-trial Series that were plotted.
    """
    by_trial = trials.groupby('trial')[column]
    mean, sem = by_trial.mean(), by_trial.sem()
    xs = np.arange(0, 10)
    sns.lineplot(ax=ax, x=xs, y=mean, lw=3, color=color)
    sns.scatterplot(ax=ax, x=xs, y=mean, s=50, color=color)
    ax.fill_between(xs, mean - sem, mean + sem, alpha=.2, color=color)
    return mean, sem


# Each condition is selected once and reused by all three panels.
easy_trials = full_df.query('abs(mu1-mu2)>=19')
hard_trials = full_df.query('abs(mu1-mu2)<=1')

fig, (ax0,ax1,ax2) = plt.subplots(1,3,figsize=(16,4),gridspec_kw={'width_ratios': [1, 2.5, 2.5]})

# --- ax0: per-subject mean of `best` in each condition -------------------------
# The two names used to be bound to the opposite conditions; the bar order
# (>=19 first, <=1 second) and therefore the figure are unchanged.
acc_easy = easy_trials.groupby('subject')['best'].mean()
acc_hard = hard_trials.groupby('subject')['best'].mean()

dd = pd.DataFrame({'ACC': np.concatenate([acc_easy.values, acc_hard.values]),
                   'Condition': np.concatenate([np.repeat(0, len(acc_easy)),
                                                np.repeat(1, len(acc_hard))])})

sns.barplot(ax=ax0, data=dd, x='Condition', y='ACC',
            palette=[color_blue, color_grey],
            ec='k')

ax0.set_ylim(0,1)
ax0.set_yticks([.1,.9])
# Labels drop the leading zero ('.1', '.9').
ax0.set_yticklabels([str(t)[1:] for t in np.round(ax0.get_yticks(), 3)], fontsize=24)
ax0.set_ylabel('Accuracy',size=25,labelpad=-10)
ax0.set_xticklabels([r'$\geq19$',r'$\leq1$'],fontsize=18)
ax0.set_xlabel('Value difference',size=22)

# --- ax1: beta by trial ---------------------------------------------------------
mean_beta_easy, sem_beta_easy = _plot_mean_sem_by_trial(ax1, easy_trials, 'beta', color_blue)
mean_beta_hard, sem_beta_hard = _plot_mean_sem_by_trial(ax1, hard_trials, 'beta', color_grey)

ax1.set_ylabel(r'$\beta$',size=30,labelpad=-30)
ax1.tick_params(axis='y', which='major', labelsize=24)
ax1.set_yticks([2.,2.4])
ax1.set_ylim(1.95,2.45)

ax1.set_xticks([1,5,9])
ax1.tick_params(axis='x', which='major', labelsize=24)
ax1.set_xlabel('Trial',size=24,labelpad=0)

# --- ax2: gamma by trial --------------------------------------------------------
mean_gamma_easy, sem_gamma_easy = _plot_mean_sem_by_trial(ax2, easy_trials, 'gamma', color_blue)
mean_gamma_hard, sem_gamma_hard = _plot_mean_sem_by_trial(ax2, hard_trials, 'gamma', color_grey)

ax1.set_xlim(-0.1,9.1)
ax2.set_xlim(-0.1,9.1)
ax2.set_ylabel(r'$\gamma$',size=30,labelpad=-20)
ax2.tick_params(axis='y', which='major', labelsize=24)
ax2.set_yticks([0.6,0.8])
ax2.set_ylim(0.58,0.82)

ax2.set_xticks([1,5,9])
ax2.tick_params(axis='x', which='major', labelsize=24)
ax2.set_xlabel('Trial',size=24,labelpad=0)

legend_elements = [Line2D([0],[0] ,lw=5, color=color_blue, label='Easy blocks'+r'$\geq19$'),
                   Line2D([0],[0] ,lw=5, color=color_grey, label='Hard blocks'+r'$\leq1$'),
                  ]
ax1.legend(handles=legend_elements,loc='lower left', fontsize=16,framealpha=0.5)

# Optional one-sided comparisons between the two conditions (not run here):
# stats.ttest_ind(hard_trials['beta'], easy_trials['beta'], alternative='less')
# stats.ttest_ind(hard_trials['gamma'], easy_trials['gamma'], alternative='greater')

sns.despine()

fig.tight_layout()

plt.savefig('../plots/fig_4C_D_E.pdf',bbox_inches='tight')
plt.show()


In [None]:
# Example of a single-action block (fig. gamma_a): one block of one subject.
# Top panel: action markers annotated with the reward, plus 1 - p_0 per row.
# Bottom panel: gamma over the same rows.
cur = full_df[(full_df.subject==9) & (full_df.block==10)].reset_index()

fig, (ax0,ax1) = plt.subplots(2,1,figsize=(6,3))

sns.scatterplot(ax=ax0,x=np.arange(0,10),y=cur.action,color=color_red,marker='o', s=30)

sns.lineplot(ax=ax0,x=np.arange(0,10),y=1-cur.p_0,lw=2)
sns.scatterplot(ax=ax0,x=np.arange(0,10),y=1-cur.p_0,marker='o', s=30)

sns.lineplot(ax=ax1,x=np.arange(0,10),y=cur.gamma,lw=2,color=color_blue)

ax0.axhline(y=0.5,ls='--',color='k')

# Write each row's reward just above its action marker.
for i in range(10):
    ax0.text(x=i-0.17,
             y=cur.action.values[i]+0.1,
             s=cur.reward.values[i],
             size=15,
             color=color_red
            )

for ax in [ax0,ax1]:
    ax.set_xlim(-.2,9.3)

ax0.set_ylabel(r'$p(a_{r})$',size=22,labelpad=8)
ax0.tick_params(axis='y', which='major', labelsize=18)

ax1.set_ylabel(r'$\gamma$',size=22)
ax1.tick_params(axis='y', which='major', labelsize=18)

ax0.set_xticks([0,1,2,3,4,5,6,7,8,9])
ax0.set_xticklabels([])
ax1.set_xticks([0,1,2,3,4,5,6,7,8,9])
ax1.tick_params(axis='x', which='major', labelsize=18)
ax1.set_xlabel('Trial',size=20)

ax1.set_yticks([.3,.9])
ax0.set_yticks([0,1])
ax0.set_ylim(-0.1,1.1)

# Relabel the gamma axis without the leading zero ('.3', '.9'). The axes are
# addressed explicitly as ax1; the call previously went through the loop
# variable `ax` left over from the xlim loop above, which only happened to be
# ax1 because that loop ended on it.
ax1.set_yticklabels([str(x)[1:] for x in np.round(ax1.get_yticks(), 3)])

sns.despine()

plt.savefig('../plots/fig_gamma_a.pdf',bbox_inches='tight')

plt.show()


In [None]:
# Mean gamma around single-action blocks (~70 blocks): blocks in which the
# subject took the same action on every row.
mid = color_blue

single_action_windows = []   # windows centred on a single-action block
mixed_action_windows = []    # windows centred on any other block

for s in range(44):
    subject_df = full_df[full_df.subject == s]
    for b in range(1, 18):
        actions = subject_df.loc[subject_df.block == b, 'action'].values
        if actions.size == 0:
            # No rows for this subject/block; indexing actions[0] would raise.
            continue
        # Three-block window: the block itself plus the one before and after.
        window = subject_df[(subject_df.block >= b - 1) & (subject_df.block <= b + 1)].copy()
        # Re-number the rows 0..29; this assignment requires exactly 30 rows
        # (presumably 3 blocks x 10 trials -- TODO confirm).
        window['trial'] = np.arange(0, 30)
        if np.all(actions == actions[0]):
            single_action_windows.append(window)
        else:
            mixed_action_windows.append(window)

# Positions 6..14 of the window: the last 4 rows of the preceding block and the
# first 5 rows of the centre block, plotted as offsets -4..4.
gamma_single = pd.concat(single_action_windows).groupby('trial')['gamma']
mean_gamma = gamma_single.mean().values[6:15]
se_gamma = gamma_single.sem().values[6:15]

# Same statistics for the remaining blocks; computed but not drawn in this figure.
gamma_mixed = pd.concat(mixed_action_windows).groupby('trial')['gamma']
mean_gamma_y = gamma_mixed.mean().values[6:15]
se_gamma_y = gamma_mixed.sem().values[6:15]

offsets = np.arange(-4, 5)

fig, (ax0) = plt.subplots(1,1,figsize=(6,3))

ax0.fill_between(offsets, mean_gamma - se_gamma, mean_gamma + se_gamma, alpha=.2)
sns.scatterplot(ax=ax0, x=offsets, y=mean_gamma, s=30, color=mid, edgecolor='k')
sns.lineplot(ax=ax0, x=offsets, y=mean_gamma, color=mid)
ax0.axvline(x=0, ls='--', color='k')  # first row of the single-action block

ax0.set_xticks(np.arange(-4,5,2))
ax0.tick_params(axis='x', which='major', labelsize=20)
ax0.set_xlabel('Trial position from single action block',size=20)

ax0.tick_params(axis='y', which='major', labelsize=20)
ax0.set_ylabel('Direct explor. ' r'$\gamma$',size=20)

fig.tight_layout()
sns.despine()
plt.savefig('../plots/fig_gamma_b.pdf',bbox_inches='tight')

plt.show()
