In [1]:
import numpy as np
import pandas as pd
from tqdm import tqdm

import multiprocessing as mp
import pickle 
import warnings 
warnings.filterwarnings('ignore')

from imports import*
from utils import *
from logistic_regression import *
from rnn import *

from hybrid_sim import *
from hybrid_fit import *
from hybrid_predict import *

import torch
# Use the GPU when PyTorch reports CUDA as available, otherwise fall back to the CPU;
# the chosen device is printed so it is visible in the notebook output.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

from scipy.stats import pearsonr , spearmanr

cuda


In [2]:
# Per-subject IQ table (one 'SubjectID' / 'IQ' pair per row is what the later
# look-ups assume) and the trial-level task data.
df_iq = pd.read_csv('../data/wasi.csv')
df_tst = pd.read_csv('../data/TST.csv')

# Subject identifiers, in the order they appear in the IQ table.
ids = df_iq.loc[:, 'SubjectID'].values

def update_data_frame(df):
    """Add one-trial-back columns to ``df`` in place (nothing is returned).

    Columns written:

    * ``prev_reward``     - ``reward`` of the previous row (0 on the first row).
    * ``transition_prev`` - ``transition_type`` of the previous row (0 on the
      first row).
    * ``stay``            - True where ``action_stage_1`` equals its value on
      the previous row. The first row has no predecessor, so it compares
      against NaN and is False.
    """
    lagged_columns = (
        ('prev_reward', 'reward'),
        ('transition_prev', 'transition_type'),
    )
    for target, source in lagged_columns:
        df[target] = df[source].shift(1, fill_value=0)

    first_stage_choice = df['action_stage_1']
    df['stay'] = first_stage_choice.shift(1).eq(first_stage_choice)
    
def sigmoid(x):
    """Logistic function ``1 / (1 + exp(-x))``, evaluated with NumPy.

    Accepts a scalar or an array; arrays are processed element-wise.
    """
    denominator = 1.0 + np.exp(-x)
    return 1.0 / denominator

def inverse_sigmoid(y):
    """Log-odds ``log(y / (1 - y))``, evaluated with NumPy.

    Accepts a scalar or an array; no clipping is applied, so ``y`` equal to
    0 or 1 yields an infinite result.
    """
    odds = y / (1 - y)
    return np.log(odds)

# One frame per subject listed in the IQ table. Subject 25510 is skipped
# (the reason is not recorded here - presumably a known-bad record; TODO confirm).
# reset_index() keeps the original row labels in an 'index' column, which the
# cleaning step further down drops again.
all_dfs = [
    df_tst[df_tst['subjectID'] == subject_id].reset_index()
    for subject_id in ids
    if subject_id != 25510
]

# Keep only subjects with exactly 443 rows in total
# (original author's note on this number: "121 322 443").
new_all_dfs = [subject_df for subject_df in all_dfs if subject_df.shape[0] == 443]
        
# Split every retained subject's rows by measurement session. The second
# reset_index() stores the per-subject row numbers in a 'level_0' column
# (the name pandas picks because an 'index' column already exists).
block_0, block_1, block_2 = (
    [
        subject_df[subject_df['measurement'] == session].reset_index()
        for subject_df in new_all_dfs
    ]
    for session in ('baseline', 'followup', 'six_month')
)

# Subject id of every retained participant, read off the baseline frames so
# the order matches the block lists.
ids = [baseline_df['subjectID'].unique()[0] for baseline_df in block_0]

# IQ value for each of those subjects, in the same order as `ids`.
IQs = [df_iq[df_iq['SubjectID'] == subject_id].IQ.values[0] for subject_id in ids]
    
# Raw task column names -> names used by the rest of the analysis.
_COLUMN_RENAMES = {
    'choice1': 'action_stage_1',
    'choice2': 'action_stage_2',
    '2nd_stage_state': 'state_of_stage_2',
    'transition': 'transition_type',
}

# Columns removed before modelling: the two bookkeeping columns created by the
# earlier reset_index() calls plus the raw task columns listed here.
_DROPPED_COLUMNS = [
    'level_0', 'index', 'measurement', 'key1', 'key2', 'rt1', 'rt2',
    'iti', '1st_stage_stim_left', '1st_stage_stim_right',
    '2nd_stage_stim_left', '2nd_stage_stim_right',
    'p1', 'p2', 'p3', 'p4',
]


def _prepare_block_frame(raw_block):
    """Return a cleaned copy of one subject's frame for one measurement block.

    The same steps used to be pasted once per block; they are applied here in
    the original order:

    1. rename the raw task columns (see ``_COLUMN_RENAMES``);
    2. drop the columns in ``_DROPPED_COLUMNS`` (raises ``KeyError`` if any is
       missing, exactly as before);
    3. recode ``action_stage_1`` and ``state_of_stage_2`` by subtracting 1 and
       ``action_stage_2`` modulo 2 (presumably to obtain 0/1 codes - confirm
       against the raw data coding);
    4. add the one-trial-back columns via ``update_data_frame``;
    5. drop every row that still contains a missing value.

    ``raw_block`` itself is left untouched.
    """
    cleaned = raw_block.rename(columns=_COLUMN_RENAMES).drop(columns=_DROPPED_COLUMNS)

    cleaned['action_stage_1'] -= 1
    cleaned['action_stage_2'] %= 2
    cleaned['state_of_stage_2'] -= 1

    # Lagged columns are computed before dropna(), so "previous trial" means
    # the previous recorded row, including rows that are removed afterwards.
    update_data_frame(cleaned)
    return cleaned.dropna()


# all_blocks[b][n] is the cleaned frame of participant n in block b
# (0 = baseline, 1 = followup, 2 = six_month).
all_blocks = [
    [_prepare_block_frame(raw_block) for raw_block in block]
    for block in (block_0, block_1, block_2)
]
block_0, block_1, block_2 = all_blocks

In [3]:
# Number of participants: every block list holds one frame per participant.
num_of_agents = len(all_blocks[0])

# Number of measurement blocks.
num_of_block = 3

# Cross-validation schedule: row r is the block order rolled by r positions,
# i.e. [[0, 1, 2], [2, 0, 1], [1, 2, 0]]. The fitting cells unpack each row
# as (train, val, test) block indices.
array = np.arange(num_of_block)
cv = np.array([np.roll(array, shift) for shift in range(num_of_block)])

# Model registry keyed by model name. The fitting cell uses entry 2 as the
# fit callable and entry 3 as the predict callable.
models = {
    'hybrid': [
        configuration_parameters_hybrid,
        hybrid_sim,
        hybrid_fit,
        hybrid_predict,
    ],
}

def bce(y_hat, y_true):
    """Summed binary cross-entropy between probabilities and 0/1 targets.

    ``y_hat`` is the predicted probability of the target being 1 and
    ``y_true`` the observed target. A constant of 1e-7 is added inside both
    logarithms so that probabilities of exactly 0 or 1 stay finite. The
    result is the sum over all elements, not the mean.
    """
    epsilon = 1e-7
    hit_term = y_true * np.log(y_hat + epsilon)
    miss_term = (1 - y_true) * np.log(1 - y_hat + epsilon)
    return -np.sum(hit_term + miss_term)

In [None]:
# Result columns for the hybrid model; each list collects one length-N array
# per cross-validation fold and is concatenated at the end of the cell.
data_results = {
    'agent': [],
    'train_block': [],
    'train_nll_hybrid': [],
    'val_nll_hybrid': [],
    'test_nll_hybrid': [],
}

K = 5              # number of repeated fits per participant and fold
N = num_of_agents  # number of participants
data_results['agent'].append(np.tile(np.arange(0, N), len(cv)))
data_results['train_block'].append(np.repeat(cv[:, 0], N))


def _first_stage_nll(predict_fn, block_df, parameters):
    """Summed negative log-likelihood of the first-stage choices in ``block_df``.

    ``predict_fn(block_df, parameters)`` must return a 3-tuple whose second
    element is the per-trial prediction ``y_hat``. The score is
    ``bce(1 - y_hat, action_stage_1)``; presumably ``y_hat`` is the probability
    of first-stage action 0, hence the ``1 - y_hat`` - confirm against the
    predict function.
    """
    _, y_hat, _ = predict_fn(block_df, parameters)
    return bce(1 - y_hat, block_df['action_stage_1'].values)


for m in tqdm(models):

    print(f'*** Fit with {m} ***')
    fit_fn, predict_fn = models[m][2], models[m][3]

    for train, val, test in cv:
        print(f'*** train {train} | val {val} | test {test} ***')

        # Fit every participant's training block K times. The context manager
        # guarantees the worker processes are shut down even if a fit raises;
        # previously the pool was only close()d on the success path.
        # NOTE(review): a fresh pool is still created per repeat, as before.
        # With a fork start method each new pool inherits the parent's RNG
        # state, so if the fit function draws its starting point from the
        # global NumPy RNG the K repeats may not be independent - confirm.
        fit_res = []
        for _ in range(K):
            with mp.Pool(processes=mp.cpu_count()) as pool:
                fit_res.append(pool.map(fit_fn, all_blocks[train]))

        # Score every repeat on the train / validation / test block.
        all_nll_train = np.zeros(shape=(K, N))
        all_nll_val = np.zeros(shape=(K, N))
        all_nll_test = np.zeros(shape=(K, N))
        for k in range(K):
            for n in range(N):
                parameters = fit_res[k][n].x
                all_nll_train[k, n] = _first_stage_nll(predict_fn, all_blocks[train][n], parameters)
                all_nll_val[k, n] = _first_stage_nll(predict_fn, all_blocks[val][n], parameters)
                all_nll_test[k, n] = _first_stage_nll(predict_fn, all_blocks[test][n], parameters)

        # Train and validation scores are each the minimum over the K repeats;
        # the test score is taken from the repeat with the lowest validation
        # score, so it is not selected on test data.
        best_train = all_nll_train.min(axis=0)
        best_val = all_nll_val.min(axis=0)
        indx = np.argmin(all_nll_val, axis=0)
        best_test = np.array([all_nll_test[indx[n], n] for n in range(N)])

        data_results[f'train_nll_{m}'].append(best_train)
        data_results[f'val_nll_{m}'].append(best_val)
        data_results[f'test_nll_{m}'].append(best_test)


# Flatten the per-fold arrays into one column each (fold-major, participant-minor).
for column in data_results:
    data_results[column] = np.concatenate(data_results[column])
df_the = pd.DataFrame(data_results)


In [None]:
K = 1              # number of lag settings tried; the lag runs from 1 to K
N = num_of_agents  # number of participants

# Result columns for the logistic-regression baseline; one length-N array is
# appended per cross-validation fold and concatenated at the end of the cell.
data_results_lr = {
    'train_nll_lr': [],
    'val_nll_lr': [],
    'test_nll_lr': [],
}


def _lr_first_stage_nll(clf, block_df, lag):
    """Summed negative log-likelihood of the first-stage choices in ``block_df``.

    ``clf`` is the classifier returned by ``fit_logistic_regression`` or
    ``None`` when no classifier is available. In the ``None`` case the
    chance-level score ``-log(0.5)`` per trial is returned for every trial of
    ``block_df``. The previous code hard-coded 200 trials there, which did not
    match the number of trials scored in the regular branch.

    Otherwise column 0 of ``predict_proba`` is used as ``y_hat`` and the score
    is ``bce(1 - y_hat, action_stage_1)``.
    """
    # Pre-processing is always executed, as before, in case it has side
    # effects on ``block_df``.
    X, _ = preprocess_logistic_regression(block_df, lag=lag)
    if clf is None:
        return -np.log(.5) * len(block_df)
    y_hat = clf.predict_proba(X)[:, 0]
    return bce(1 - y_hat, block_df['action_stage_1'].values)


for train, val, test in cv:
    print(f'*** train {train} | val {val} | test {test} ***')

    all_nll_train = np.zeros(shape=(K, N))
    all_nll_val = np.zeros(shape=(K, N))
    all_nll_test = np.zeros(shape=(K, N))

    # fit_res[k][n]: classifier for participant n fitted with lag k + 1.
    fit_res = []
    for k in range(K):
        cur_res = []
        for n in range(N):
            X, y = preprocess_logistic_regression(all_blocks[train][n], lag=k + 1)
            clf, _intercept, _coefficients = fit_logistic_regression(X, y)
            cur_res.append(clf)
        fit_res.append(cur_res)

    # Score every fitted classifier on the train / validation / test block.
    for k in range(K):
        for n in range(N):
            clf = fit_res[k][n]
            all_nll_train[k, n] = _lr_first_stage_nll(clf, all_blocks[train][n], k + 1)
            all_nll_val[k, n] = _lr_first_stage_nll(clf, all_blocks[val][n], k + 1)
            all_nll_test[k, n] = _lr_first_stage_nll(clf, all_blocks[test][n], k + 1)

    # Train and validation scores are each the minimum over the K settings;
    # the test score comes from the setting with the lowest validation score.
    best_train = all_nll_train.min(axis=0)
    best_val = all_nll_val.min(axis=0)
    indx = np.argmin(all_nll_val, axis=0)
    best_test = np.array([all_nll_test[indx[n], n] for n in range(N)])

    data_results_lr['train_nll_lr'].append(best_train)
    data_results_lr['val_nll_lr'].append(best_val)
    data_results_lr['test_nll_lr'].append(best_test)

# Flatten the per-fold arrays into one column each (fold-major, participant-minor).
for column in data_results_lr:
    data_results_lr[column] = np.concatenate(data_results_lr[column])
df_lr = pd.DataFrame(data_results_lr)
            

In [None]:
# Persist the IQ list (same participant order as the block lists) with the
# highest pickle protocol available.
with open('../results/iq.pickle', 'wb') as iq_file:
    pickle.dump(IQs, iq_file, pickle.HIGHEST_PROTOCOL)

In [None]:
# Place the hybrid-model columns and the logistic-regression columns side by
# side (rows are aligned on the default integer index) and export them.
df = pd.concat([df_the, df_lr], axis='columns')
df.to_csv('../results/hybrid7_lr1_emp.csv')


In [4]:
N = num_of_agents  # number of participants

# Network and optimisation settings passed to GRU_NN / train_model.
INPUT_SIZE = 5
OUTPUT_SIZE = 2
LERANING_RATE = 0.001

hidden_size = 5
num_layers = 1
epochs = 1000

# One entry is appended per (participant, fold) pair, participant-major:
# entry 3 * n + f belongs to participant n and row f of `cv`.
loss_train, loss_val, loss_test = [], [], []
ll_train, ll_val, ll_test = [], [], []

for n in tqdm(range(N)):
    for train, val, test in cv:

        # Wrap the participant's three blocks as datasets ...
        train_data, val_data, test_data = (
            behavior_dataset(all_blocks[block][n]) for block in (train, val, test)
        )

        # ... and serve each one as a single, unshuffled, full-size batch.
        train_loader, val_loader, test_loader = (
            DataLoader(split, shuffle=False, batch_size=len(split))
            for split in (train_data, val_data, test_data)
        )

        # A fresh network is trained for every participant and fold.
        rnn = GRU_NN(INPUT_SIZE, hidden_size, num_layers, OUTPUT_SIZE)
        rnn, train_loss, train_ll, val_loss, val_ll, test_loss, test_ll = train_model(
            rnn,
            train_loader,
            val_loader,
            test_loader,
            epochs=epochs,
            lr=LERANING_RATE,
        )

        loss_train.append(train_loss)
        loss_val.append(val_loss)
        loss_test.append(test_loss)

        ll_train.append(train_ll)
        ll_val.append(val_ll)
        ll_test.append(test_ll)

    print('Done agent', n)


# Persist every collected list to its own pickle file, in the original order.
for path, values in (
    ('../results/loss_train.pickle', loss_train),
    ('../results/loss_val.pickle', loss_val),
    ('../results/loss_test.pickle', loss_test),
    ('../results/ll_train.pickle', ll_train),
    ('../results/ll_val.pickle', ll_val),
    ('../results/ll_test.pickle', ll_test),
):
    with open(path, 'wb') as handle:
        pickle.dump(values, handle, protocol=pickle.HIGHEST_PROTOCOL)
        

  2%|█▌                                                                                 | 1/54 [01:00<53:17, 60.32s/it]

Done agent 0


  4%|███                                                                                | 2/54 [02:05<54:55, 63.37s/it]

Done agent 1


  6%|████▌                                                                              | 3/54 [03:03<51:47, 60.94s/it]

Done agent 2


  7%|██████▏                                                                            | 4/54 [03:53<46:54, 56.29s/it]

Done agent 3


  9%|███████▋                                                                           | 5/54 [04:40<43:14, 52.95s/it]

Done agent 4


 11%|█████████▏                                                                         | 6/54 [05:26<40:40, 50.85s/it]

Done agent 5


 13%|██████████▊                                                                        | 7/54 [06:13<38:47, 49.52s/it]

Done agent 6


 15%|████████████▎                                                                      | 8/54 [07:00<37:18, 48.66s/it]

Done agent 7


 17%|█████████████▊                                                                     | 9/54 [07:46<35:54, 47.87s/it]

Done agent 8


 19%|███████████████▏                                                                  | 10/54 [08:33<34:51, 47.54s/it]

Done agent 9


 20%|████████████████▋                                                                 | 11/54 [09:20<33:56, 47.37s/it]

Done agent 10


 22%|██████████████████▏                                                               | 12/54 [10:07<33:00, 47.16s/it]

Done agent 11


 24%|███████████████████▋                                                              | 13/54 [10:53<32:09, 47.05s/it]

Done agent 12


 26%|█████████████████████▎                                                            | 14/54 [11:39<31:09, 46.74s/it]

Done agent 13


 28%|██████████████████████▊                                                           | 15/54 [12:26<30:20, 46.69s/it]

Done agent 14


 30%|████████████████████████▎                                                         | 16/54 [13:12<29:30, 46.59s/it]

Done agent 15


 31%|█████████████████████████▊                                                        | 17/54 [13:59<28:45, 46.63s/it]

Done agent 16


 33%|███████████████████████████▎                                                      | 18/54 [14:46<28:03, 46.77s/it]

Done agent 17


 35%|████████████████████████████▊                                                     | 19/54 [15:34<27:23, 46.96s/it]

Done agent 18


 37%|██████████████████████████████▎                                                   | 20/54 [16:20<26:35, 46.92s/it]

Done agent 19


 39%|███████████████████████████████▉                                                  | 21/54 [17:07<25:44, 46.80s/it]

Done agent 20


 41%|█████████████████████████████████▍                                                | 22/54 [17:54<24:57, 46.78s/it]

Done agent 21


 43%|██████████████████████████████████▉                                               | 23/54 [18:41<24:13, 46.88s/it]

Done agent 22


 44%|████████████████████████████████████▍                                             | 24/54 [19:27<23:24, 46.81s/it]

Done agent 23


 46%|█████████████████████████████████████▉                                            | 25/54 [20:14<22:35, 46.75s/it]

Done agent 24


 48%|███████████████████████████████████████▍                                          | 26/54 [21:01<21:49, 46.75s/it]

Done agent 25


 50%|█████████████████████████████████████████                                         | 27/54 [21:48<21:04, 46.83s/it]

Done agent 26


 52%|██████████████████████████████████████████▌                                       | 28/54 [22:34<20:15, 46.75s/it]

Done agent 27


 54%|████████████████████████████████████████████                                      | 29/54 [23:21<19:29, 46.78s/it]

Done agent 28


 56%|█████████████████████████████████████████████▌                                    | 30/54 [24:08<18:43, 46.80s/it]

Done agent 29


 57%|███████████████████████████████████████████████                                   | 31/54 [24:55<17:57, 46.85s/it]

Done agent 30


 59%|████████████████████████████████████████████████▌                                 | 32/54 [25:42<17:13, 47.00s/it]

Done agent 31


 61%|██████████████████████████████████████████████████                                | 33/54 [26:29<16:26, 46.97s/it]

Done agent 32


 63%|███████████████████████████████████████████████████▋                              | 34/54 [27:16<15:40, 47.00s/it]

Done agent 33


 65%|█████████████████████████████████████████████████████▏                            | 35/54 [28:03<14:48, 46.77s/it]

Done agent 34


 67%|██████████████████████████████████████████████████████▋                           | 36/54 [28:50<14:03, 46.85s/it]

Done agent 35


 69%|████████████████████████████████████████████████████████▏                         | 37/54 [29:37<13:17, 46.92s/it]

Done agent 36


 70%|█████████████████████████████████████████████████████████▋                        | 38/54 [30:24<12:31, 46.96s/it]

Done agent 37


 72%|███████████████████████████████████████████████████████████▏                      | 39/54 [31:10<11:42, 46.86s/it]

Done agent 38


 74%|████████████████████████████████████████████████████████████▋                     | 40/54 [31:57<10:55, 46.81s/it]

Done agent 39


 76%|██████████████████████████████████████████████████████████████▎                   | 41/54 [32:44<10:08, 46.78s/it]

Done agent 40


 78%|███████████████████████████████████████████████████████████████▊                  | 42/54 [33:30<09:20, 46.67s/it]

Done agent 41


 80%|█████████████████████████████████████████████████████████████████▎                | 43/54 [34:17<08:34, 46.81s/it]

Done agent 42


 81%|██████████████████████████████████████████████████████████████████▊               | 44/54 [35:03<07:46, 46.61s/it]

Done agent 43


 83%|████████████████████████████████████████████████████████████████████▎             | 45/54 [35:50<06:59, 46.67s/it]

Done agent 44


 85%|█████████████████████████████████████████████████████████████████████▊            | 46/54 [36:37<06:13, 46.67s/it]

Done agent 45


 87%|███████████████████████████████████████████████████████████████████████▎          | 47/54 [37:23<05:26, 46.57s/it]

Done agent 46


 89%|████████████████████████████████████████████████████████████████████████▉         | 48/54 [38:10<04:40, 46.72s/it]

Done agent 47


 91%|██████████████████████████████████████████████████████████████████████████▍       | 49/54 [38:56<03:52, 46.56s/it]

Done agent 48


 93%|███████████████████████████████████████████████████████████████████████████▉      | 50/54 [39:43<03:06, 46.51s/it]

Done agent 49


 94%|█████████████████████████████████████████████████████████████████████████████▍    | 51/54 [40:30<02:19, 46.57s/it]

Done agent 50


 96%|██████████████████████████████████████████████████████████████████████████████▉   | 52/54 [41:16<01:33, 46.67s/it]

Done agent 51


 98%|████████████████████████████████████████████████████████████████████████████████▍ | 53/54 [42:03<00:46, 46.59s/it]

Done agent 52


100%|██████████████████████████████████████████████████████████████████████████████████| 54/54 [42:50<00:00, 47.59s/it]

Done agent 53



