In [None]:
import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt 
import seaborn as sns 
from scipy.stats import pearsonr
from tqdm import tqdm

from gershman_sim import *
from gershman_fit import *
from gershman_pred import *
import pickle

import warnings
import multiprocessing as mp
import time
warnings.filterwarnings('ignore')

In [None]:
# Load the trial-level data and shift the 1-based `choice`, `block` and
# `subject` codes to 0-based values (`choice` becomes the 0/1 `action`).
df = pd.read_csv('../data/data2.csv')
df['action'] = df['choice']-1
df['block'] = df['block']-1
df['subject'] = df['subject']-1

N_SUBJECTS = 44   # subjects are expected to be coded 0..N_SUBJECTS-1 after the shift
N_RESTARTS = 5    # independent fits per subject; the lowest objective is kept

# One DataFrame per subject. reset_index() keeps the original row index as an
# 'index' column, exactly as before.
all_data = []
for i in range(N_SUBJECTS):
    cur_df = df[(df['subject']==i)].reset_index()
    assert len(cur_df) > 0, f"no trials found for subject {i}"
    all_data.append(cur_df)

# Fit parameters of all subjects.
# A single pool runs every (subject x restart) job in one map, so all cores are
# used instead of only N_RESTARTS of them. The context manager terminates the
# workers once map() has returned, so no processes are left behind per subject.
# NOTE(review): if gershman_fit draws its starting point from numpy's global
# RNG, forked workers inherit the parent's RNG state and restarts may begin
# from identical points; confirm in gershman_fit and seed per job if so.
jobs = [subject_df for subject_df in all_data for _ in range(N_RESTARTS)]
with mp.Pool(processes=mp.cpu_count()) as pool:
    all_fits = pool.map(gershman_fit, jobs)

fit_arr = []
for i in range(N_SUBJECTS):
    # map() preserves order, so subject i owns a contiguous run of N_RESTARTS results
    restarts = all_fits[i*N_RESTARTS:(i+1)*N_RESTARTS]
    fun_vals = np.array([r.fun for r in restarts], dtype=float)
    if np.isnan(fun_vals).all():
        raise RuntimeError(f"all {N_RESTARTS} fits returned NaN for subject {i}")
    # nanargmin, not argmin: argmin returns the first NaN's index if any NaN is
    # present, which would select a failed fit as the "best" one.
    fit_arr.append(restarts[np.nanargmin(fun_vals)])

In [None]:
def bce_loss(y_hat, y_true, eps=1e-10):
    """Element-wise binary cross-entropy (negative log-likelihood per element).

    Parameters
    ----------
    y_hat : array-like
        Predicted probability that the outcome is 1.
    y_true : array-like
        Observed binary outcome (0 or 1), broadcastable against ``y_hat``.
    eps : float, default 1e-10
        Added inside each log so that a probability of exactly 0 or 1 yields a
        large finite loss instead of ``inf`` (or ``nan`` from ``0 * -inf``).

    Returns
    -------
    numpy.ndarray
        Per-element loss (a numpy scalar for scalar inputs); callers reduce it
        with ``.mean()`` / ``.sum()``.
    """
    # asarray lets plain lists work too (``1 - list`` would raise TypeError)
    y_hat = np.asarray(y_hat, dtype=float)
    y_true = np.asarray(y_true, dtype=float)
    return -(y_true*np.log(y_hat + eps) + (1-y_true)*np.log(1-y_hat + eps))

# Evaluate each subject's best-fitting parameters on that subject's own trials.
#   bce_arr     : mean per-trial binary cross-entropy
#   p_r2_arr    : pseudo-R^2 = 1 - NLL / NLL_chance, where the chance model
#                 predicts 0.5 on every trial (NLL_chance = N * log 2)
#   norm_ll_arr : exp(-mean BCE), the geometric-mean per-trial likelihood
#   all_p_0     : per-trial p_0 arrays returned by gershman_pred
bce_arr = []
p_r2_arr = []
all_p_0 = []
norm_ll_arr = []

# Reuse the per-subject frames built in the fitting cell instead of re-filtering df.
assert len(all_data) == len(fit_arr), "all_data and fit_arr must cover the same subjects"
for cur_df, fit in zip(all_data, fit_arr):
    _, p_0 = gershman_pred(cur_df, fit.x)   # the first return value (accuracy) is not used here
    all_p_0.append(p_0)
    # p_0 is treated as P(action == 0), so 1 - p_0 is the predicted P(action == 1)
    loss = bce_loss(1-p_0, cur_df.action.values)
    bce_arr.append(loss.mean())
    chance_nll = -len(cur_df)*np.log(0.5)
    p_r2_arr.append(1 - loss.sum() / chance_nll)
    norm_ll_arr.append(np.exp(-loss.mean()))
    

# Collect each subject's fitted parameters and fit statistics into one table.
# The subject count is taken from fit_arr rather than a repeated hard-coded 44,
# and the per-subject lists are checked to line up before they are combined.
n_fitted = len(fit_arr)
assert len(bce_arr) == len(p_r2_arr) == len(norm_ll_arr) == len(all_p_0) == n_fitted, \
    "fit statistics do not cover the same subjects as fit_arr"

# x[0] and x[1] are written out as the 'beta' and 'gamma' columns below
ind_beta = np.array([fit.x[0] for fit in fit_arr])
ind_gamma = np.array([fit.x[1] for fit in fit_arr])
ind_nll = np.array([fit.fun for fit in fit_arr])
subj = np.arange(n_fitted)

# save files
pd.DataFrame({'subject':subj,
              'beta':ind_beta,
              'gamma':ind_gamma,
              'nll':ind_nll,
              'bce':bce_arr,
              'psr2':p_r2_arr,
              'norm_ll':norm_ll_arr}).to_csv('../results/gershman_individual_theoretical.csv',index=False)

with open('../results/gershman_individual_p_0_theoretical.pickle', 'wb') as handle:
    pickle.dump(all_p_0, handle, protocol=pickle.HIGHEST_PROTOCOL)
