In [None]:
import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt 
import seaborn as sns 
from matplotlib import lines
from scipy import stats
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
from sklearn.metrics import r2_score

import warnings
# NOTE(review): this silences *every* warning for the rest of the kernel session,
# including library deprecation/future warnings that would flag API breakage
# (e.g. pandas/seaborn/matplotlib changes). Consider narrowing the filter.
warnings.filterwarnings('ignore')

In [None]:
# upload data: trial-level behaviour, plus the t-RNN parameter table that is
# joined row-wise onto the trials further down
df = pd.read_csv('../data/for_plos.csv').rename(columns={'key': 'action'})
df_param = pd.read_csv('../results/dezfouli_trnn_param.csv')

# encode the two response keys as binary actions: R1 -> 0, R2 -> 1
df['action'] = df['action'].replace({'R1': 0, 'R2': 1}).astype(int)
df['reward'] = df['reward'].astype(int)
# shift block numbers down by one (later cells index blocks as 0..11)
df['block'] = df['block'] - 1

# create unique list of names
# Map every participant ID to an integer 0..n-1 in order of first appearance.
# A single vectorised map replaces the per-name loop. It also avoids the loop's
# hazard of matching a freshly assigned integer against a not-yet-renamed ID.
UniqueNames = df.ID.unique()
dic = {name: i for i, name in enumerate(UniqueNames)}
df['ID'] = df['ID'].map(dic)
df = df.rename(columns={'ID': 'subject'}).copy()

# Diagnosis code per subject, read from the subject's first row:
# 0 = diag contains 'B' (bipolar), 1 = contains 'D' (depression), 2 = anything else (healthy).
# The per-subject frames are collected into one list per group.
label = []
b = []  # bipolar subjects' frames
d = []  # depression subjects' frames
h = []  # remaining (healthy) subjects' frames
for i in range(101):
    subj_df = df[df.subject == i]
    diag_str = subj_df.diag.values[0]
    if 'B' in diag_str:
        label.append(0)
        b.append(subj_df)
    elif 'D' in diag_str:
        label.append(1)
        d.append(subj_df)
    else:
        label.append(2)
        h.append(subj_df)

# Re-number subjects so they are grouped by diagnosis: bipolar first (0..len(b)-1),
# then depression, then healthy. Enumerating the concatenated list yields the right
# offsets even when a group is empty. The old `last_i = i + 1` silently picked up a
# stale loop index in that case.
z = b + d + h
idx = [np.repeat(new_id, len(frame)) for new_id, frame in enumerate(z)]

df = pd.concat(z).reset_index(drop=True)
df['subject'] = np.concatenate(idx)

# One frame per (re-numbered) subject. reset_index() keeps each row's position
# in df as an 'index' column.
all_data = [df[df['subject'] == s].reset_index() for s in range(101)]
n_trials = [len(frame) for frame in all_data]


# Attach the t-RNN parameter columns to the trials.
# NOTE(review): DataFrame.join aligns on the index. This assumes df_param's rows are in
# the same (diagnosis-grouped) trial order as df -- confirm upstream.
df_full = df.join(df_param.drop(columns='subject'))

# stay = 1 if the action repeats the previous action *of the same subject*.
# Shifting within each subject stops a subject's first trial from being compared with
# the previous subject's last trial; that first trial now gets stay = 0.
prev_action = df_full.groupby('subject')['action'].shift(1)
df_full['stay'] = (prev_action == df_full['action']).astype(int)

# trial_b: running trial index within each (subject, block); trial: within each subject.
# cumcount follows row order, so there is no hard-coded number of subjects/blocks, and
# no requirement that the frame be sorted by subject and block.
df_full['trial_b'] = df_full.groupby(['subject', 'block']).cumcount()
df_full['trial'] = df_full.groupby('subject').cumcount()

In [None]:
# Fig 3A: per-subject model error (the 'bce' column of each results file) for the
# QP-stationary, t-RNN and d-RNN models, split by diagnosis.
# Colours: QP-stationary, t-RNN, d-RNN (the d-RNN colour is taken from a blended palette).
drnn_color = sns.blend_palette(['tab:green', '#ffc100', '#D2292D', 'grey', '#1761B0'])[2]
pla = ['tab:green', '#1761B0', drnn_color]

# the theoretical-model results file also carries each subject's diagnosis; read it once
ind_results = pd.read_csv('../results/dezfouli_individual_theoretical.csv')
dezfouli_ind = ind_results['bce'].values
dezfouli_trnn = pd.read_csv('../results/dezfouli_trnn.csv')['bce'].values
dezfouli_drnn = pd.read_csv('../results/dezfouli_drnn.csv')['bce'].values
diag = ind_results['diag'].values

df_dezfouli = pd.DataFrame({'dezfouli_theoretical': dezfouli_ind,
                            'dezfouli_trnn': dezfouli_trnn,
                            'dezfouli_drnn': dezfouli_drnn})

# long format: the three model columns stacked one after another (101 subjects each)
df_dezfouli_t = pd.DataFrame({'bce': df_dezfouli.values.T.flatten(),
                              'Model': np.repeat(['QP stationary', 't-RNN', 'd-RNN'], 101),
                              'data': np.repeat('Dezfouli', 101 * 3),
                              'Diag': np.tile(diag, 3)
                              })


def _add_sig_bracket(ax, x0, x1, y, text, fontsize, text_y=None, leg=0.0):
    """Draw a bracket from x0 to x1 at height y (legs of height `leg`) and centre `text` on it."""
    line = lines.Line2D([x0, x0, x1, x1], [y, y + leg, y + leg, y],
                        lw=1, c='0.2', transform=ax.transData)
    line.set_clip_on(False)
    ax.add_line(line)
    ax.annotate(text, xy=(np.mean([x0, x1]), y if text_y is None else text_y),
                xytext=(0, 1), textcoords='offset points',
                xycoords='data', ha='center', va='bottom',
                fontsize=fontsize, clip_on=False, annotation_clip=False)


fig, ax = plt.subplots(1, 1, figsize=(4, 4))

sns.barplot(ax=ax, data=df_dezfouli_t, x='Diag', y='bce', hue='Model', palette=pla,
            edgecolor='k', errorbar="se")
# NOTE(review): assumes Diag values appear in the order Bipolar, Depression, Healthy -- confirm
ax.set_xticklabels(['Bipolar', 'Depression', 'Healthy'], fontsize=15)
ax.set_xlabel('', size=18)

# Patch styling must be passed as keywords; the old positional form
# Patch([0], [0], ...) is rejected by recent matplotlib.
legend_elements = [
    Patch(color=pla[0], label='QP-stationary'),
    Patch(color=pla[1], label='t-RNN'),
    Patch(color=pla[2], label='d-RNN'),
]
ax.legend(handles=legend_elements, fontsize=14, bbox_to_anchor=(0.6, 1.05))

ax.set_ylim(0, .52)
ax.set_yticks([.1, .4])
ax.set_yticklabels(['.1', '.4'], fontsize=22)
ax.set_ylabel('Error (lower is better)', size=18, labelpad=2)

# Hard-coded significance annotations (no test is computed in this cell).
# Each label is centred on its own bracket.
_add_sig_bracket(ax, -.26, .01, .47, '**', fontsize=16, text_y=.46)
_add_sig_bracket(ax, .74, 1.01, .47, '**', fontsize=16, text_y=.46)
_add_sig_bracket(ax, 1.74, 2.0, .31, 'ns', fontsize=14)

sns.despine()
plt.savefig('../plots/fig_3A.pdf', bbox_inches='tight')
plt.show()


In [None]:
# Fig 3B: distribution of the t-RNN kappa values over all trials (one value per
# row of df_full), split by diagnosis.
diag_order = ['Bipolar', 'Depression', 'Healthy']
fig, ax1 = plt.subplots(1, 1, figsize=(4.5, 4))
co = ['#bc193e', '#35a042', '#2435b1']
# pin the category order so the hard-coded tick labels below always match the boxes
sns.boxplot(ax=ax1, x=df_full['diag'], y=df_full['kappa'].values, order=diag_order,
            showfliers=False, palette=co, saturation=1, width=0.6,
            medianprops={"color": "black", "linewidth": 3})

ax1.set_xticklabels(diag_order, fontsize=16)
ax1.set_xlabel('', size=18)

ax1.set_ylim(-.6, 0.6)
ax1.set_yticks([-.4, 0, .4])
ax1.set_yticklabels([-.4, 0, .4], fontsize=18)
ax1.set_ylabel('Perseveration ' + r'$\kappa$ ' + '\ndistribution', size=18, labelpad=-10)

sns.despine()
plt.savefig('../plots/fig_3B.pdf', bbox_inches='tight')
plt.show()

df_full.groupby('diag')['kappa'].describe()


In [None]:
# Fig 3C: per-subject Pearson correlation between the t-RNN kappa and a moving
# average (window N) of the stay indicator. The first N trials of each subject are dropped.
N = 10
all_pearson = []
for subj in range(101):
    subj_rows = df_full[df_full.subject == subj]
    kappa_tail = subj_rows.kappa.values[N:]
    stay_ma_tail = subj_rows.stay.rolling(N).mean().values[N:]
    all_pearson.append(stats.pearsonr(kappa_tail, stay_ma_tail)[0])

p_med = np.median(all_pearson)
print(pd.Series(all_pearson).describe())

fig, ax = plt.subplots(figsize=(5, 4))
sns.histplot(all_pearson, bins=10, color='#1761B0', ax=ax)
ax.set_xlabel('Pearson correlation coefficient', size=15)
ax.set_ylabel('Count', size=14)
ax.axvline(x=p_med, ls='--', color='k')
ax.text(x=p_med - .26, y=18, s='Median\n  %.2f' % p_med, fontsize=14)
ax.tick_params(axis='both', which='major', labelsize=14)
sns.despine()
fig.savefig('../plots/fig_3C.pdf', bbox_inches='tight')
plt.show()

In [None]:
# Run the QP model for two example subjects (26 and 59). Keep the second output
# of qp_pred per trial; Fig 3D plots it.
# NOTE(review): qp_fit's return value is discarded, and qp_pred gets hard-coded
# parameter lists -- presumably taken from an earlier fit; confirm. The star imports
# hide which names come from these project modules.
from qp_pred import *
from qp_fit import *

all_p_0_qs = []
for subj_idx, qp_params in ((26, [0, 2.6, 1]), (59, [0, 0.7, 0])):
    qp_fit(all_data[subj_idx])
    _, p_0 = qp_pred(all_data[subj_idx], qp_params)
    all_p_0_qs.append(p_0)

In [None]:
# Fig 3D (left): subject 26, trials 493-542.
# Top: observed actions and each model's predicted P(a_R) (plotted as 1 - p_0).
# Bottom: t-RNN kappa vs. a constant QP-stationary value.
subj, t_start, t_stop = 26, 493, 543
n_win = t_stop - t_start
window = df_full[(df_full.subject == subj) & (df_full.trial >= t_start) & (df_full.trial < t_stop)]
trial_axis = np.arange(n_win)

actions = window.action
logit_R = window.p_0
logit_Q = all_p_0_qs[0][t_start:t_stop]
PR = window.kappa
# NOTE(review): hard-coded constant, presumably this subject's QP-stationary kappa -- confirm
PS = np.repeat(0.48, n_win)

fig, (ax0, ax1) = plt.subplots(2, 1, figsize=(6, 4))

# --- draw ---
sns.scatterplot(ax=ax0, x=trial_axis, y=actions, color='#D2292D', marker='o', s=30)
sns.lineplot(ax=ax0, x=trial_axis, y=1 - logit_R, color='#1761B0', lw=2)
sns.lineplot(ax=ax0, x=trial_axis, y=1 - logit_Q, color='tab:green', lw=2)
ax0.axhline(y=0.5, ls='--', color='k')

sns.lineplot(ax=ax1, x=trial_axis, y=PR, color='#1761B0', lw=2)
sns.lineplot(ax=ax1, x=trial_axis, y=PS, color='tab:green', lw=2)
ax1.axhline(y=0, ls='--', color='k')

# --- top panel formatting ---
ax0.set_xlim(0, 50)
ax0.set_xticks([5, 25, 45])
ax0.set_xticklabels([])
ax0.set_yticks([0, .5, 1])
ax0.set_yticklabels(['0', '.5', '1'], fontsize=18)
ax0.set_ylabel(r'$P(a_{R})$', size=22, labelpad=0)
ax0.tick_params(axis='y', which='major', labelsize=20)

# --- bottom panel formatting ---
ax1.set_xlim(0, 50)
ax1.set_xticks([5, 25, 45])
ax1.set_xticklabels([5, 25, 45])
ax1.tick_params(axis='x', which='major', labelsize=20)
ax1.set_xlabel('Trial', size=20)
ax1.set_ylim(-.52, .52)
ax1.set_yticks([-.5, 0, .5])
ax1.set_yticklabels(['-.5', '0', '.5'], fontsize=18)
ax1.tick_params(axis='y', which='major', labelsize=20)
ax1.set_ylabel(r'$\kappa$', size=25, labelpad=-10)

sns.despine()
plt.savefig('../plots/fig_3D_left.pdf', bbox_inches='tight')
plt.show()

In [None]:
# Fig 3D (right): subject 59, trials 425-474. Same layout as the left panel, plus the legend.
subj, t_start, t_stop = 59, 425, 475
n_win = t_stop - t_start
window = df_full[(df_full.subject == subj) & (df_full.trial >= t_start) & (df_full.trial < t_stop)]
trial_axis = np.arange(n_win)

actions = window.action
logit_R = window.p_0
logit_Q = all_p_0_qs[1][t_start:t_stop]
PR = window.kappa
# NOTE(review): hard-coded constant, presumably this subject's QP-stationary kappa -- confirm
PS = np.repeat(-0.48, n_win)

fig, (ax0, ax1) = plt.subplots(2, 1, figsize=(6, 4))

# --- draw ---
sns.scatterplot(ax=ax0, x=trial_axis, y=actions, color='#D2292D', marker='o', s=30)
sns.lineplot(ax=ax0, x=trial_axis, y=1 - logit_R, color='#1761B0', lw=2)
sns.lineplot(ax=ax0, x=trial_axis, y=1 - logit_Q, color='tab:green', lw=2)
ax0.axhline(y=0.5, ls='--', color='k')

sns.lineplot(ax=ax1, x=trial_axis, y=PR, color='#1761B0', lw=2)
sns.lineplot(ax=ax1, x=trial_axis, y=PS, color='tab:green', lw=2)
ax1.axhline(y=0, ls='--', color='k')

# --- top panel formatting ---
ax0.set_xlim(0, 50)
ax0.set_xticks([5, 25, 45])
ax0.set_xticklabels([])
ax0.set_yticks([0, .5, 1])
ax0.set_yticklabels(['0', '.5', '1'], fontsize=18)
ax0.set_ylabel(r'$P(a_{R})$', size=22, labelpad=0)
ax0.tick_params(axis='y', which='major', labelsize=20)

# --- bottom panel formatting ---
ax1.set_xlim(0, 50)
ax1.set_xticks([5, 25, 45])
ax1.set_xticklabels([5, 25, 45])
ax1.tick_params(axis='x', which='major', labelsize=20)
ax1.set_xlabel('Trial', size=20)
ax1.set_ylim(-.52, .52)
ax1.set_yticks([-.5, 0, .5])
ax1.set_yticklabels(['-.5', '0', '.5'], fontsize=18)
ax1.tick_params(axis='y', which='major', labelsize=20)
ax1.set_ylabel(r'$\kappa$', size=25, labelpad=-10)

legend_elements = [
    Line2D([0], [0], lw=0, marker='.', color='#D2292D', ms=14, label='True action'),
    Line2D([0], [0], lw=3, marker='o', ms=0, color='tab:green', label='QP-stationary'),
    Line2D([0], [0], lw=3, marker='o', ms=0, color='#1761B0', label='t-RNN'),
]
ax1.legend(handles=legend_elements,
           bbox_to_anchor=(0.62, 0.15),
           fontsize=18, fancybox=True, framealpha=0.8)

sns.despine()
plt.savefig('../plots/fig_3D_right.pdf', bbox_inches='tight')
plt.show()

In [None]:
# Fig 3E: mean kappa as a function of trial-within-block (first 100 positions), per
# diagnosis. The band is +/- SEM over all rows sharing that trial_b.
fig, ax = plt.subplots(1, 1, figsize=(6, 3))

co = ['#bc193e', '#35a042', '#2435b1', '#f79605']

for color, group in zip(co, ['Bipolar', 'Depression', 'Healthy']):
    # Select the column before aggregating. DataFrameGroupBy.mean()/.sem() over the
    # whole frame raises on non-numeric columns such as 'diag' in pandas >= 2.0.
    kappa_by_trial = df_full.loc[df_full.diag == group].groupby('trial_b')['kappa']
    z_0 = kappa_by_trial.mean().values[:100]
    z_1 = kappa_by_trial.sem().values[:100]
    ax.plot(z_0, color=color)
    ax.fill_between(np.arange(len(z_0)), z_0 + z_1, z_0 - z_1, color=color, alpha=.2)

ax.set_ylabel(r'$\kappa$', size=20)

ax.set_xlim(-1, 101)
ax.set_ylim(0.05, .45)

# group labels placed at fixed positions on the right edge of the plot
ax.text(x=100, y=0.2, s='Bipolar', fontsize=14)
ax.text(x=100, y=0.3, s='Depression', fontsize=14)
ax.text(x=100, y=0.4, s='Healthy', fontsize=14)
ax.set_xticks([0, 20, 40, 60, 80, 100])
ax.set_xticklabels([0, 20, 40, 60, 80, 100], fontsize=16)

ax.set_xlabel('Trial', size=20)
ax.tick_params(axis='both', which='major', labelsize=16)

sns.despine()
plt.savefig('../plots/fig_3E.pdf', bbox_inches='tight')
plt.show()


In [None]:
# Distribution of the t-RNN alpha values over all trials (one value per row of
# df_full), split by diagnosis.
diag_order = ['Bipolar', 'Depression', 'Healthy']
fig, ax1 = plt.subplots(1, 1, figsize=(4.5, 4))
co = ['#bc193e', '#35a042', '#2435b1']
# pin the category order so the hard-coded tick labels below always match the boxes
sns.boxplot(ax=ax1, x=df_full['diag'], y=df_full['alpha'].values, order=diag_order,
            showfliers=False, palette=co, saturation=1, width=0.6,
            medianprops={"color": "black", "linewidth": 3})

ax1.set_xticklabels(diag_order, fontsize=16)
ax1.set_xlabel('', size=18)
ax1.set_ylabel(r'$\alpha$ ' + 'distribution', size=18, labelpad=0)

sns.despine()
# saving is disabled for this supplementary panel:
# plt.savefig('../plots/fig_alpha_distribution.pdf',bbox_inches='tight')
plt.show()

df_full.groupby('diag')['alpha'].describe()


In [None]:
# Distribution of the t-RNN beta values over all trials (one value per row of
# df_full), split by diagnosis.
diag_order = ['Bipolar', 'Depression', 'Healthy']
fig, ax1 = plt.subplots(1, 1, figsize=(4.5, 4))
co = ['#bc193e', '#35a042', '#2435b1']
# pin the category order so the hard-coded tick labels below always match the boxes
sns.boxplot(ax=ax1, x=df_full['diag'], y=df_full['beta'].values, order=diag_order,
            showfliers=False, palette=co, saturation=1, width=0.6,
            medianprops={"color": "black", "linewidth": 3})

ax1.set_xticklabels(diag_order, fontsize=16)
ax1.set_xlabel('', size=18)
ax1.set_ylabel(r'$\beta$ ' + 'distribution', size=18, labelpad=0)

sns.despine()
# saving is disabled for this supplementary panel:
# plt.savefig('../plots/fig_beta_distribution.pdf',bbox_inches='tight')
plt.show()

df_full.groupby('diag')['beta'].describe()


In [None]:
# Mean t-RNN beta as a function of trial-within-block (first 100 positions), per
# diagnosis. The band is +/- SEM over all rows sharing that trial_b.
fig, ax = plt.subplots(1, 1, figsize=(6, 3))

co = ['#bc193e', '#35a042', '#2435b1', '#f79605']
groups = ['Bipolar', 'Depression', 'Healthy']

for color, group in zip(co, groups):
    # Select the column before aggregating. DataFrameGroupBy.mean()/.sem() over the
    # whole frame raises on non-numeric columns such as 'diag' in pandas >= 2.0.
    beta_by_trial = df_full.loc[df_full.diag == group].groupby('trial_b')['beta']
    z_0 = beta_by_trial.mean().values[:100]
    z_1 = beta_by_trial.sem().values[:100]
    ax.plot(z_0, color=color)
    ax.fill_between(np.arange(len(z_0)), z_0 + z_1, z_0 - z_1, color=color, alpha=.2)

ax.set_ylabel(r'$\beta$', size=20)
ax.set_xlim(-1, 101)

legend_elements = [Line2D([0], [0], lw=2, color=color, label=group)
                   for color, group in zip(co, groups)]
ax.legend(handles=legend_elements, fontsize=14, loc='center left', framealpha=0.5)
ax.set_xticks([0, 20, 40, 60, 80, 100])
ax.set_xticklabels([0, 20, 40, 60, 80, 100], fontsize=16)

ax.set_xlabel('Trial', size=16)
# applied last, so tick labels on both axes end up at size 14
ax.tick_params(axis='both', which='major', labelsize=14)

sns.despine()
plt.savefig('../plots/fig_beta_dynamics.pdf', bbox_inches='tight')
plt.show()


In [None]:
# Per-subject Pearson correlation between the t-RNN beta and a moving average
# (window N) of |p_0 - 0.5|. The first N trials of each subject are dropped.
# (Leftover per-subject debug plotting and the empty plt.show() inside the loop removed.)
all_pearson = []
N = 10
for i in range(101):
    subj_rows = df_full[df_full.subject == i]
    x = subj_rows.beta.values[N:]
    y = np.abs(subj_rows.p_0 - 0.5).rolling(N).mean().values[N:]
    all_pearson.append(stats.pearsonr(x, y)[0])

p_med = np.median(all_pearson)
print(pd.Series(all_pearson).describe())

fig, ax = plt.subplots(figsize=(5, 4))
sns.histplot(all_pearson, bins=10, color='#1761B0', ax=ax)
ax.set_xlabel('Pearson correlation coefficient', size=15)
ax.set_ylabel('Count', size=14)
ax.axvline(x=p_med, ls='--', color='k')
ax.text(x=p_med - 0.06, y=18.5, s='Median\n  %.2f' % p_med, fontsize=14)
ax.tick_params(axis='both', which='major', labelsize=14)
sns.despine()
fig.savefig('../plots/fig_beta_pearson.pdf', bbox_inches='tight')
plt.show()

In [None]:
# Supplementary figure: t-RNN beta (scaled by 1/20) vs. a moving average (window N) of
# |p_0 - 0.5|, for one example subject per diagnosis group (rows: bipolar, depression,
# healthy). Left: time courses. Right: binned regression of the scaled beta on the moving
# average, with the identity line.
pla = ['tab:green', '#1761B0']
# NOTE(review): bb is loaded but not used in this cell
bb = pd.read_csv('../results/dezfouli_individual_theoretical.csv')['beta'].values

N = 10

fig, ax = plt.subplots(3, 2, figsize=(14, 6), gridspec_kw={'width_ratios': [4, 1]})

b_n, d_n, h_n = 0, 50, 76

for row, subj in enumerate([b_n, d_n, h_n]):
    subj_rows = df_full[df_full.subject == subj]
    n_subj = len(subj_rows)
    x = np.abs(subj_rows.p_0 - 0.5).rolling(N).mean().values[N:]
    y = subj_rows.beta.values[N:] / 20
    left, right = ax[row, 0], ax[row, 1]

    # left: time courses
    left.plot(np.arange(n_subj - N), x, '#D2292D', lw=3)
    left.plot(np.arange(n_subj - N), y, color='#1761B0', lw=2, ls=(0, (5, 1)))
    if row < 2:
        left.set_xticks([])
        left.set_xlim(-1, n_subj)
    else:
        left.set_xticks([0, 200, 400, 600, 800, 1000])
        left.tick_params(axis='x', which='major', labelsize=18)
        left.set_xlabel('Trial', size=18)
        left.set_xlim(-1, 1050)
    left.set_ylim(0.0, 0.5)
    left.set_yticks([0.1, 0.4])
    left.set_yticklabels([0.1, 0.4], size=16)
    left.tick_params(axis='y', which='major', labelsize=16)
    left.set_ylabel(r'$\beta$', size=22, labelpad=-10)

    # right: binned regression plus identity line
    sns.regplot(ax=right, x=x, y=y, x_bins=10, color=pla[1], line_kws={'lw': 2})
    sns.lineplot(ax=right, x=[0, 0.5], y=[0, 0.5], ls='--', color='#D2292D', lw=3)
    ro, p = stats.pearsonr(x=x, y=y)
    # pearsonr returns r (not r^2), so the annotation is labelled r
    right.text(x=0.3, y=0.1, s=r'$r=%.3f$' % ro, color=pla[1], size=14)
    if row < 2:
        right.set_xticks([])
    else:
        right.set_xticks([0.1, 0.4])
        right.tick_params(axis='x', which='major', labelsize=16)
        # x here is the P(a_t) moving average (same quantity as the red line), not stay
        right.set_xlabel(r'$P(a_t)$ MA', size=18)
    right.set_xlim(0, 0.5)
    right.set_ylim(0, 0.5)
    right.set_yticks([0.1, 0.4])
    right.set_yticklabels([0.1, 0.4], size=16)
    right.set_ylabel(r'$\beta$', size=22, labelpad=-10)

legend_elements = [
    Line2D([0], [0], lw=2, color='#D2292D', label=r'$P(a_t)$ MA'),
    Line2D([0], [0], lw=2, color=pla[1], label='t-RNN'),
]
ax[2, 0].legend(handles=legend_elements, fontsize=14, loc='upper left', framealpha=1)

sns.despine()
plt.savefig('../plots/fig_random_cp_ma.pdf', bbox_inches='tight')
plt.show()

In [None]:
# Supplementary figure: stay moving average (window N) vs. kappa, for one example subject
# per diagnosis group (rows: bipolar, depression, healthy).
# Left: time courses. Kappa is plotted shifted by +0.5 onto the stay scale, and the
# y tick labels at 0.1/0.9 read -0.4/0.4 to undo that shift.
# Right: kappa vs. stay MA for the t-RNN (per-trial kappa) and the QP-stationary model
# (one kk value per subject).
pla = ['tab:green', '#1761B0']
# NOTE(review): kk is indexed below by the re-numbered subject id. This assumes the file's
# rows follow the same subject order as df_full -- confirm.
kk = pd.read_csv('../results/dezfouli_individual_theoretical.csv')['kappa'].values

N = 10
# seeded generator so the jitter added below (and the r it produces) is reproducible
rng = np.random.default_rng(0)

fig, ax = plt.subplots(3, 2, figsize=(14, 6), gridspec_kw={'width_ratios': [4, 1]})

b_n, d_n, h_n = 0, 34, 93
right_ylims = [(-0.51, 0.51), (-0.51, 0.51), (-0.55, 0.55)]

for row, subj in enumerate([b_n, d_n, h_n]):
    subj_rows = df_full[df_full.subject == subj]
    n_ma = len(subj_rows) - N
    trials = np.arange(n_ma)
    x = subj_rows['stay'].rolling(N).mean().values[N:]
    y = subj_rows.kappa.values[N:]
    # kk is a single value per subject, so it is jittered to keep pearsonr/regplot
    # defined. The QP-stationary r shown below therefore only reflects that jitter.
    z = np.repeat(kk[subj] - 0.5, n_ma) + rng.uniform(-0.01, 0.01, n_ma)
    left, right = ax[row, 0], ax[row, 1]

    # left: time courses
    left.plot(trials, x, '#D2292D', lw=3)
    left.plot(trials, y + 0.5, color='#1761B0', lw=2, ls=(0, (5, 1)))
    left.plot(trials, np.repeat(kk[subj], n_ma), 'tab:green', lw=2, ls=(0, (5, 1)))
    if row < 2:
        left.set_xticks([])
        left.set_xlim(-1, len(subj_rows))
    else:
        left.set_xticks([0, 200, 400, 600, 800, 1000])
        left.tick_params(axis='x', which='major', labelsize=18)
        left.set_xlabel('Trial', size=18)
        left.set_xlim(-1, 1050)
    left.set_ylim(-0.1, 1.1)
    left.set_yticks([0.1, 0.9])
    left.set_yticklabels([-0.4, 0.4], size=16)
    left.tick_params(axis='y', which='major', labelsize=16)
    left.set_ylabel(r'$\kappa$', size=22, labelpad=-10)

    # right: binned (x_estimator) regressions on the stay MA, plus the reference line
    # kappa = stay MA - 0.5
    sns.regplot(ax=right, x=x, y=z, scatter=True, color=pla[0],
                line_kws={'lw': 2}, x_estimator=np.mean)
    sns.regplot(ax=right, x=x, y=y, scatter=True, color=pla[1],
                line_kws={'lw': 2}, x_estimator=np.mean)
    sns.lineplot(ax=right, x=[0, 1], y=[-.5, .5], ls='--', color='#D2292D', lw=3)
    # pearsonr returns r (not r^2), so the annotations are labelled r
    ro, p = stats.pearsonr(x=x, y=z)
    right.text(x=0.55, y=-0.3, s=r'$r=%.3f$' % ro, color=pla[0], size=14)
    ro, p = stats.pearsonr(x=x, y=y)
    right.text(x=0.55, y=-0.45, s=r'$r=%.3f$' % ro, color=pla[1], size=14)
    if row < 2:
        right.set_xticks([])
    else:
        right.set_xticks([0, 1])
        right.tick_params(axis='x', which='major', labelsize=18)
        right.set_xlabel('Stay MA', size=18)
    right.set_xlim(-0.1, 1.1)
    right.set_ylim(*right_ylims[row])
    right.set_yticks([-0.4, 0.4])
    right.set_yticklabels([-0.4, 0.4], size=16)
    right.set_ylabel(r'$\kappa$', size=22, labelpad=-10)

legend_elements = [
    Line2D([0], [0], lw=2, color='#D2292D', label='Stay MA'),
    Line2D([0], [0], lw=2, color=pla[0], label='QP-stationary'),
    Line2D([0], [0], lw=2, color=pla[1], label='t-RNN'),
]
ax[2, 0].legend(handles=legend_elements, fontsize=14, loc='lower left', framealpha=0.5)

sns.despine()
# plt.savefig('../plots/fig_stay_ma.pdf',bbox_inches='tight')
plt.show()