In [1]:
import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt 
import seaborn as sns 
from scipy.stats import pearsonr
from gershman_sim import *
from gershman_fit import *

import warnings
import multiprocessing as mp
import time
warnings.filterwarnings('ignore')

In [2]:
# simulate agents
# Seed NumPy's global RNG so the sampled parameters are identical on every
# Restart & Run All.
# NOTE(review): the simulated choices are only reproducible if gershman_sim
# also draws from NumPy's global RNG -- confirm in gershman_sim.py.
SEED = 0
np.random.seed(SEED)

n_trials = 10
n_blocks = 20
n_agent = 500

# Row i holds the true (beta, gamma) pair used to simulate agent i.
t_parameters = np.zeros(shape=(n_agent, 2))
all_data = []

for agent in range(n_agent):
    # Both parameters are drawn independently from U(0, 5).
    beta = np.random.uniform(0, 5)
    gamma = np.random.uniform(0, 5)
    t_parameters[agent] = beta, gamma
    sim = gershman_sim(agent, t_parameters[agent], n_blocks, n_trials)
    # One DataFrame per agent, in agent order, ready to be fitted one by one.
    all_data.append(pd.DataFrame(sim))

In [None]:
# recover parameters
start = time.time()
# The context manager tears the worker pool down even when a fit raises, so a
# failed run cannot leave worker processes behind.
with mp.Pool(processes=mp.cpu_count()) as pool:
    # pool.map returns results in input order: fit_arr[i] belongs to all_data[i].
    fit_arr = pool.map(gershman_fit, all_data)
    # Stop the clock before pool teardown, as the timing did before.
    end = time.time()
print(end - start)

# Stack the per-agent estimates (`.x` of each fit result) into one array whose
# rows line up with the order of all_data.
r_parameters = np.array([fit.x for fit in fit_arr])

In [None]:
# Parameter-recovery figure: one panel per parameter, true value on x and
# recovered value on y.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

sns.despine()

for col, (ax, name) in enumerate(zip((ax1, ax2), ('beta', 'gamma'))):
    # Dashed identity line over [0, 5], the range the true values were drawn from.
    sns.lineplot(ax=ax, x=[0, 5], y=[0, 5], ls='--', color='crimson')
    sns.regplot(ax=ax, x=t_parameters[:, col], y=r_parameters[:, col],
                color='royalblue')
    ax.set_xlabel('True ' + name)
    ax.set_ylabel('Recovered ' + name)
plt.show()


# Pearson correlation between true and recovered values, beta first, then gamma.
for col in range(2):
    print(pearsonr(x=t_parameters[:, col], y=r_parameters[:, col]))

In [2]:
# Load the behavioural data and subtract 1 from choice/block/subject
# (presumably 1-based in the CSV -- confirm against the data dictionary).
# NOTE(review): `trial` is left as loaded here, whereas the LOO cell further
# down also subtracts 1 from it -- confirm which one gershman_fit expects.
df = pd.read_csv('../data/data2.csv')
df['action'] = df['choice'] - 1
for col in ('block', 'subject'):
    df[col] = df[col] - 1

# One frame per subject id 0..43. reset_index() without drop=True keeps the
# original row labels in an extra `index` column.
all_data = [df[df['subject'] == subj].reset_index() for subj in range(44)]
    
# recover parameters of all agents
start = time.time()
# The context manager tears the worker pool down even when a fit raises, so a
# failed run cannot leave worker processes behind.
with mp.Pool(processes=mp.cpu_count()) as pool:
    # pool.map returns results in input order: fit_arr[i] belongs to all_data[i].
    fit_arr = pool.map(gershman_fit, all_data)
    # Stop the clock before pool teardown, as the timing did before.
    end = time.time()
print(end - start)

from gershman_pred import *

def bce_loss(y_hat, y_true):
    """Element-wise binary cross-entropy between predictions and labels.

    Parameters
    ----------
    y_hat : predicted probability of the label being 1.
    y_true : observed labels (0 or 1).

    Returns
    -------
    The unreduced loss, one value per element; callers average it themselves.
    """
    eps = 1e-10  # added inside both logs so that log(0) is never evaluated
    log_p_one = np.log(y_hat + eps)
    log_p_zero = np.log(1 - y_hat + eps)
    return -(y_true * log_p_one + (1 - y_true) * log_p_zero)
# In-sample fit quality: mean cross-entropy of each subject's fitted model on
# that subject's own choices.
all_bce_0 = []
for subj_df, fit in zip(all_data, fit_arr):
    # Only the second value returned by gershman_pred is used.
    # NOTE(review): bce_loss receives 1-b, so b is presumably P(action == 0)
    # -- confirm against gershman_pred.
    _, b = gershman_pred(subj_df, fit.x)
    all_bce_0.append(bce_loss(1 - b, subj_df.action.values).mean())

# A bare expression in the middle of a cell is never displayed, so the
# across-subject mean has to be printed explicitly.
print(np.array(all_bce_0).mean())

# Fitted parameters, one row per subject (shown as the cell's output).
pd.DataFrame(np.array([fit.x for fit in fit_arr]),
             columns=['beta', 'gamma'])

4.6123528480529785


In [20]:
# LOO
from gershman_pred import *
from tqdm import tqdm

def bce_loss(y_hat, y_true):
    """Element-wise binary cross-entropy between predictions and labels.

    Parameters
    ----------
    y_hat : predicted probability of the label being 1.
    y_true : observed labels (0 or 1).

    Returns
    -------
    The unreduced loss, one value per element; callers average it themselves.
    """
    eps = 1e-10  # added inside both logs so that log(0) is never evaluated
    log_p_one = np.log(y_hat + eps)
    log_p_zero = np.log(1 - y_hat + eps)
    return -(y_true * log_p_one + (1 - y_true) * log_p_zero)

# Reload the behavioural data and subtract 1 from choice/block/trial/subject
# (presumably 1-based in the CSV -- confirm against the data dictionary).
df = pd.read_csv('../data/data2.csv')
df['action'] = df['choice'] - 1
df['block'] = df['block'] - 1
df['trial'] = df['trial'] - 1
df['subject'] = df['subject'] - 1

# Rebuild the per-subject list from scratch. Appending to the existing list
# would stack these frames on top of whatever an earlier cell (or an earlier
# run of this cell) left in `all_data`, and would fail on a fresh kernel.
all_data = [df[df['subject'] == i].reset_index() for i in range(44)]

# Leave-one-subject-out evaluation: for every subject, fit on all other
# subjects T times and score each fit on the held-out subject.
T = 2
N_SUBJECTS = 44

all_bce_0, param_0, ll_0 = [], [], []

for i in tqdm(range(N_SUBJECTS)):
    for rep in range(T):
        train = df[df['subject'].ne(i)].reset_index()
        test = df[df['subject'].eq(i)].reset_index()

        res = gershman_fit(train)
        # Only the second value returned by gershman_pred is used.
        _, b = gershman_pred(test, res.x)

        ll_0.append(res.fun)
        param_0.append(res.x)
        held_out_loss = bce_loss(1 - b, test.action.values)
        all_bce_0.append(held_out_loss.mean())

# Results were appended subject-major, so reshape to (subject, repetition).
tar = np.array(all_bce_0).reshape(N_SUBJECTS, T)
objective = np.array(ll_0).reshape(N_SUBJECTS, T)
# For each subject keep the held-out loss of the repetition whose fit reported
# the lowest objective value (res.fun).
ind = np.argmin(objective, axis=1)
test_0 = tar[np.arange(N_SUBJECTS), ind]

test_0.mean()

100%|███████████████████████████████████████████| 44/44 [29:52<00:00, 40.74s/it]
