In [16]:
import numpy as np
import pandas as pd
import os
import mvpa_base_functions as bf
import pickle

from sklearn.model_selection import GroupKFold
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import GridSearchCV
from sklearn.linear_model import LogisticRegression

In [17]:
# Analysis settings: these three values are interpolated into every
# data/result file path used in the cells below.
trial_block = 1    # trial-block index embedded in the file names
mask = "fearnet"   # feature-mask name embedded in the file names
phase = "cond"     # experimental-phase label embedded in the file names

In [18]:
# Load the discovery dataset; it is used below to train the classifier.
# NOTE: unpickling can execute arbitrary code -- only load files from a trusted source.
discovery_file = f'./sample_data/discovery_{phase}_{mask}_{trial_block}block.pkl'
with open(discovery_file, 'rb') as fh:  # context manager guarantees the handle is closed
    D = pickle.load(fh)
csm_data = D['csm_data']
csp_data = D['csp_data']

# Stack the two conditions row-wise: csm rows first (label 0), then csp rows (label 1).
# Assumes csp_data has the same number of rows as csm_data, in the same
# row order -- TODO confirm against the data files.
subj_num, feat_num = csm_data.shape
X = np.concatenate((csm_data, csp_data), axis=0)
Y = np.concatenate((np.zeros(subj_num), np.ones(subj_num)), axis=0)
# Row i of each condition receives group id i, so the two rows that share an
# index share a group.
# NOTE(review): `group` is not passed to any cross-validator in the visible
# cells -- confirm whether a grouped CV (e.g. GroupKFold) was intended.
group = np.concatenate((np.arange(subj_num), np.arange(subj_num)), axis=0)

# Candidate values for the regularisation parameter C:
# `param_num` log-spaced values from 1e-4 to 1e4.
param_num = 10
param_list = np.logspace(-4, 4, num=param_num, base=10)
# C=1 is only a placeholder; the grid search overrides it through 'clf__C'.
LR = LogisticRegression(C=1, solver='liblinear')
# Standardise the features, then fit the logistic-regression classifier.
dclf = Pipeline([('scaler', StandardScaler()), ('clf', LR)])
param_grid = {'clf__C': param_list}


In [19]:
## option 1, train the classifier based on the sample discovery dataset
# Grid-search the 'clf__C' values in `param_grid` using GridSearchCV's default
# cross-validation; `fit` returns the fitted search object itself.
# NOTE(review): no `cv=` / `groups=` argument is supplied, so the `group` array
# built earlier is not used here -- confirm that this is intended.
clf = GridSearchCV(estimator=dclf, param_grid=param_grid, n_jobs=10).fit(X, Y)

## option 2, load the trained classifier based on the whole discovery dataset
# if mask=='fearnet':
#     file = f'./models/{phase}_threat_circuit_model.pkl'
# elif mask=='wholebrain':
#     file = f'{phase}_whole_brain_model.pkl'
# D = pickle.load(open(file, 'rb'))
# all_clf = D['all_clf']
# clf = all_clf[trial_block-1]

In [20]:
# Load the validation dataset and apply the trained classifier to it.
valid_file = f'./sample_data/validation1_{phase}_{mask}_{trial_block}block.pkl'
svfile = f'./sample_results/validation1_{phase}_{mask}_{trial_block}block_accuracy.pkl'  # file to save the results
# NOTE: unpickling can execute arbitrary code -- only load files from a trusted source.
with open(valid_file, 'rb') as fh:  # context manager guarantees the handle is closed
    D = pickle.load(fh)
val_csm_data = D['csm_data']
val_csp_data = D['csp_data']
# Size the labels from the VALIDATION array. Reading the discovery array's
# row count here would mis-size val_Y whenever the two cohorts differ in size.
val_subj_num = val_csm_data.shape[0]
# csm rows first (label 0), then csp rows (label 1).
# Assumes val_csp_data has the same number of rows as val_csm_data -- TODO confirm.
val_X = np.concatenate((val_csm_data, val_csp_data), axis=0)
val_Y = np.concatenate((np.zeros(val_subj_num), np.ones(val_subj_num)), axis=0)

# Column 1 of predict_proba is the predicted probability of label 1.
val_prob_pred = clf.predict_proba(val_X)
# Presumably returns (overall accuracy, per-item accuracy vector); this is
# inferred from the names and later use -- verify in mvpa_base_functions.
val_accuracy, val_acc_vec = bf.force_binary_accuracy(val_Y, val_prob_pred[:, 1])

In [21]:
# Permutation test: refit the grid search on permuted discovery labels and
# score each refit on the validation set. This is time consuming because every
# permutation repeats the full grid search.
perm_num = 10  # increase this value to at least 1000 for real analysis
val_accuracy_permutation = np.zeros(perm_num)
for iperm in range(perm_num):
    # Permuted copy of the training labels (the scheme is defined in mvpa_base_functions).
    rY = bf.permute_Y(Y)
    rclf = GridSearchCV(estimator=dclf, param_grid=param_grid, n_jobs=10).fit(X, rY)

    # Score the permuted-label model exactly like the real one; keep only the
    # first returned value (the overall accuracy).
    rprob_pred = rclf.predict_proba(val_X)
    racc = bf.force_binary_accuracy(val_Y, rprob_pred[:, 1])[0]
    val_accuracy_permutation[iperm] = racc

In [24]:
# Save these summary variables for plotting.
mvals = np.mean(val_acc_vec)                                   # mean of val_acc_vec
evals = np.std(val_acc_vec) / np.sqrt(val_acc_vec.shape[0])    # std (numpy default ddof=0) / sqrt(n)
chance_vals = np.percentile(val_accuracy_permutation, 97.5, axis=0)  # 97.5th percentile of the permutation accuracies

# Create the output directory if needed, so a missing folder does not discard
# the (expensive) permutation results with a FileNotFoundError.
out_dir = os.path.dirname(svfile)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)
# The context manager flushes and closes the file as soon as the dump finishes.
with open(svfile, 'wb') as fh:
    pickle.dump({'chance_vals': chance_vals,
                 'mvals': mvals,
                 'evals': evals},
                fh)

In [None]:
## gather results from all phases and trial-blocks
## run this cell after obtaining the results for each phase and trial-block
## i.e., when all files f'./sample_results/validation1_{phase}_{mask}_{trial_block}block_accuracy.pkl' are ready

# mask = 'fearnet' # feature to be used, 'fearnet' or 'wholebrain'
# all_chance_vals = []
# all_mvals = []
# all_evals = []
# for cphs in ['cond', 'ext', 'recall']:
#     for tblk in [1,2,3,4]:
#         file = f'./sample_results/validation1_{cphs}_{mask}_{tblk}block_accuracy.pkl'
#         D = pickle.load(open(file, 'rb'))
#         chance_vals = D['chance_vals']
#         mvals = D['mvals']
#         evals = D['evals']
#         all_mvals.append(mvals)
#         all_evals.append(evals)
#         all_chance_vals.append(chance_vals)
# all_mvals = np.array(all_mvals)
# all_evals = np.array(all_evals)
# all_chance_vals = np.array(all_chance_vals)

# if mask=='wholebrain':
#     savefile = './results/whole_brain_external_validation_accuracy.pkl'
# elif mask=='fearnet':
#     savefile = './results/threat_circuit_external_validation_accuracy.pkl'
# pickle.dump({'all_chance_vals':all_chance_vals,
#              'all_mvals':all_mvals,
#              'all_evals':all_evals},
#              open(savefile, 'wb'))