In [None]:
import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt 
import seaborn as sns 
from tqdm import tqdm
import pickle
from scipy.stats import pearsonr

from qp_fit import *
from qp_pred import *
from utils import *

import warnings
import multiprocessing as mp
import time
warnings.filterwarnings('ignore')

In [None]:
# Load the trial-level data (one row per trial) and rename the response column 'key' -> 'action'.
df = pd.read_csv('../data/for_plos.csv')
df = df.rename(columns={'key': 'action'})

# Encode the two responses as integers: 'R1' -> 0, 'R2' -> 1.
df.loc[df["action"] == "R1", "action"] = 0
df.loc[df["action"] == "R2", "action"] = 1

df['action'] = df['action'].astype(int)
df['reward'] = df['reward'].astype(int)
# Shift block numbers down by one (presumably 1-based -> 0-based; confirm against the raw data).
df['block'] = df['block'] - 1

# Map every raw participant ID to a contiguous integer code (order of first appearance).
# A single .map() pass cannot collide the way a per-ID relabel loop can when raw IDs are
# themselves small integers (rows already relabelled to i would be re-matched later), and
# it covers every participant instead of assuming exactly 101.
UniqueNames = df.ID.unique()
dic = {name: code for code, name in enumerate(UniqueNames)}
df['ID'] = df['ID'].map(dic)
df = df.rename(columns={'ID': 'subject'}).copy()

# Assign each subject a group label from the first entry of its `diag` column:
# 0 if the string contains 'B', otherwise 1 if it contains 'D', otherwise 2.
# Subject frames are collected per group (b / d / h) in ascending subject order.
label = []
b, d, h = [], [], []
for _, subj_df in df.groupby('subject', sort=True):
    diag_str = subj_df['diag'].values[0]
    if 'B' in diag_str:
        label.append(0)
        b.append(subj_df)
    elif 'D' in diag_str:
        label.append(1)
        d.append(subj_df)
    else:
        label.append(2)
        h.append(subj_df)

# Reorder rows so all label-0 subjects come first, then label 1, then label 2, and
# renumber subjects 0..N-1 in that order. Enumerating the concatenated list keeps the
# numbering contiguous even when one of the groups is empty.
z = b + d + h
idx = [np.repeat(new_id, len(subj_df)) for new_id, subj_df in enumerate(z)]

df = pd.concat(z).reset_index(drop=True)
df['subject'] = np.concatenate(idx)


# One DataFrame per renumbered subject. reset_index() (without drop) keeps the old row
# index as an 'index' column, and the per-subject row count is the number of trials.
n_subjects = len(b) + len(d) + len(h)
all_data = [df[df['subject'] == i].reset_index() for i in range(n_subjects)]
n_trials = [len(subj_df) for subj_df in all_data]

# Group boundaries follow the renumbering above (label 0, then 1, then 2), so they are
# derived from the group sizes rather than hard-coded.
n_bipolar = len(b)
n_depression = len(d)
bipolar_data = df[df.subject < n_bipolar]
depression_data = df[(df.subject >= n_bipolar) & (df.subject < n_bipolar + n_depression)]
healthy_data = df[df.subject >= n_bipolar + n_depression]

# Fit every subject with qp_fit N_RESTARTS times and keep the run with the lowest
# objective value (`.fun`). All subject x restart jobs go through one pool, which is
# closed by the `with` block, so every core stays busy instead of only N_RESTARTS at a time.
# NOTE(review): if qp_fit draws random starting points from NumPy's global RNG, forked
# workers can inherit identical RNG state and produce identical restarts — confirm that
# qp_fit seeds each call independently.
N_RESTARTS = 5
fit_jobs = [subj_df for subj_df in all_data for _ in range(N_RESTARTS)]
with mp.Pool(processes=mp.cpu_count()) as pool:
    # imap preserves job order, so results can be regrouped by subject below.
    all_restarts = list(tqdm(pool.imap(qp_fit, fit_jobs), total=len(fit_jobs), desc='fitting'))

fit_arr = []
for s in range(len(all_data)):
    restarts = all_restarts[s * N_RESTARTS:(s + 1) * N_RESTARTS]
    objectives = np.array([res.fun for res in restarts])
    fit_arr.append(restarts[np.argmin(objectives)])

In [None]:
# pred all subj actions

def bce_loss(y_hat, y_true):
    """Element-wise binary cross-entropy.

    Parameters
    ----------
    y_hat : array-like
        Predicted probability that the outcome is 1.
    y_true : array-like
        Observed outcome (0 or 1).

    Returns
    -------
    Per-element loss ``-(y*log(p) + (1-y)*log(1-p))``. A small epsilon (1e-10) is added
    inside each log so predictions of exactly 0 or 1 still give a finite value.
    """
    eps = 1e-10
    log_p_one = np.log(y_hat + eps)
    log_p_zero = np.log(1 - y_hat + eps)
    return -(y_true * log_p_one + (1 - y_true) * log_p_zero)

# Score each subject's best fit. qp_pred returns (acc, p_0) for the subject's trials;
# 1 - p_0 is used as the predicted probability of action 1.
bce_arr = []
p_r2_arr = []
all_p_0 = []
norm_ll_arr = []
for i in range(len(fit_arr)):
    cur_df = df[(df['subject'] == i)].reset_index()
    acc, p_0 = qp_pred(cur_df, fit_arr[i].x)
    # Expect one prediction per trial; a mismatched shape would otherwise broadcast silently.
    assert np.shape(p_0) == cur_df['action'].shape, \
        f'subject {i}: p_0 shape {np.shape(p_0)} != {cur_df["action"].shape}'

    all_p_0.append(p_0)
    loss = bce_loss(1 - p_0, cur_df.action.values)
    n_subj_trials = len(cur_df)
    bce_arr.append(loss.mean())
    # Pseudo-R^2 relative to a chance model that predicts 0.5 on every trial
    # (chance total loss = -n * log(0.5)).
    p_r2_arr.append(1 - loss.sum() / (-n_subj_trials * np.log(0.5)))
    # exp(-mean loss): geometric mean of the per-trial probability of the observed action.
    norm_ll_arr.append(np.exp(-loss.mean()))
    
# Collect per-subject best-fit parameters (x[0], x[1], x[2] saved as alpha, beta, kappa)
# and objective values; the subject count comes from the fits rather than a constant.
n_subj = len(fit_arr)
ind_alpha = np.array([fit.x[0] for fit in fit_arr])
ind_beta = np.array([fit.x[1] for fit in fit_arr])
ind_kappa = np.array([fit.x[2] for fit in fit_arr])
ind_ntrials = np.array(n_trials)
ind_nll = np.array([fit.fun for fit in fit_arr])
subj = np.arange(n_subj)

# Read each subject's diagnosis from its own rows instead of assuming fixed group sizes
# and the appearance order of df.diag.unique(); this stays correct if group sizes change
# or a group contains more than one diag string.
diag = np.array([df.loc[df['subject'] == s, 'diag'].values[0] for s in subj])

# save files
pd.DataFrame({'subject': subj,
              'diag': diag,
              'alpha': ind_alpha,
              'beta': ind_beta,
              'kappa': ind_kappa,
              'nll': ind_nll,
              'n_trials': ind_ntrials,
              'bce': bce_arr,
              'psr2': p_r2_arr,
              'norm_ll': norm_ll_arr}).to_csv('../results/dezfouli_individual_theoretical.csv', index=False)

with open('../results/dezfouli_individual_p_0_theoretical.pickle', 'wb') as handle:
    pickle.dump(all_p_0, handle, protocol=pickle.HIGHEST_PROTOCOL)
