In [1]:
import numpy as np
import pandas as pd
import os
import mvpa_base_functions as bf
import pickle

from sklearn.model_selection import GroupKFold
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import GridSearchCV
from sklearn.linear_model import LogisticRegression

In [2]:
# Run configuration: these three values select the dataset to analyse and
# determine both the input and the output file names built below.
phase = 'cond' # experimental phase: 'cond', 'ext', or 'recall'
mask = 'fearnet' # feature set to be used: 'fearnet' or 'wholebrain'
trial_block = 1 # trial-block id: 1, 2, 3, or 4
dtfile = f'./sample_data/discovery_{phase}_{mask}_{trial_block}block.pkl'
svfile = f'./sample_results/discovery_{phase}_{mask}_{trial_block}block_accuracy.pkl'

# Load the pickled dictionary. The context manager closes the file handle
# deterministically instead of leaving it to the garbage collector.
# NOTE(review): pickle.load can execute arbitrary code while unpickling --
# only open data files that come from a trusted source.
with open(dtfile, 'rb') as data_fh:
    D = pickle.load(data_fh)
csm_data = D['csm_data'] # presumably the CS- samples (labelled 0 downstream) -- TODO confirm
csp_data = D['csp_data'] # presumably the CS+ samples (labelled 1 downstream) -- TODO confirm

In [3]:
# Assemble the classification problem: the csm rows come first (label 0),
# followed by the csp rows (label 1).
subj_num, feat_num = csm_data.shape
X = np.concatenate([csm_data, csp_data])
Y = np.concatenate([np.zeros(subj_num), np.ones(subj_num)])

# Group ids are the row index within each half, so row i of csm_data and row i
# of csp_data share an id (presumably the same participant -- TODO confirm).
subj_ids = np.arange(subj_num)
group = np.concatenate([subj_ids, subj_ids])

In [4]:
# Hyper-parameter grid: param_num values of the regularisation parameter C,
# log-spaced between 1e-4 and 1e4.
param_num = 20
param_list = np.logspace(-4, 4, num=param_num, base=10)

# Pipeline: feature standardisation followed by logistic regression. The C=1
# given here is only a starting value; the grid search sets it via 'clf__C'.
LR = LogisticRegression(C=1, solver='liblinear')
dclf = Pipeline([('scaler', StandardScaler()), ('clf', LR)])
param_grid = {'clf__C': param_list}
# NOTE(review): no `cv` is passed, so the inner hyper-parameter search uses
# scikit-learn's default splitter, which is not group-aware. Confirm that
# this is intended before relying on the selected C values.
clf = GridSearchCV(dclf, param_grid, n_jobs=10)


gperm_num = 10 # number of times the cross-validation is repeated
cv_num = 5 # 5-fold cross-validation
all_accuracy = np.zeros(gperm_num)

# Make sure csp and csm samples from the same participant are both in the
# training set or both in the testing set. The splitter is configured only by
# n_splits, so it is built once instead of once per repetition.
cvg = GroupKFold(n_splits=cv_num)

# repeat cross-validation
for igrp in range(gperm_num):
    # presumably re-labels the group ids so that each repetition produces a
    # different fold assignment -- see mvpa_base_functions.permute_group
    rgroup = bf.permute_group(group)

    # Out-of-fold predicted probability (column 1 of predict_proba) for every
    # sample. Allocated explicitly as float so the probabilities are never
    # truncated by inheriting an integer dtype from the label vector.
    cv_prob_pred = np.zeros(np.shape(Y), dtype=float)
    for tridx, tsidx in cvg.split(X, Y, rgroup):
        # Grid search (and scaler fit) uses the training fold only.
        clf.fit(X[tridx], Y[tridx])
        cv_prob_pred[tsidx] = clf.predict_proba(X[tsidx])[:, 1]

    # Score the pooled out-of-fold predictions of this repetition.
    sc, sc_vec = bf.force_binary_accuracy(Y, cv_prob_pred)
    all_accuracy[igrp] = sc

In [5]:
rand_num = 10 # number of label permutations; time-consuming, increase to 1000 for real run
cv_num = 5 # 5-fold cross-validation
rand_accuracy = np.zeros(rand_num)

# Make sure csp and csm samples from the same participant are both in the
# training set or both in the testing set. The splitter is configured only by
# n_splits, so it is built once instead of once per permutation.
cvg = GroupKFold(n_splits=cv_num)

# permutation test: repeat the cross-validation with permuted labels
for ird in range(rand_num):
    # Permuted copy of the labels; the permutation scheme is defined in
    # mvpa_base_functions.permute_Y and is not visible here.
    rY = bf.permute_Y(Y)

    # Out-of-fold predicted probability (column 1 of predict_proba) for every
    # sample. Allocated explicitly as float so the probabilities are never
    # truncated should permute_Y return an integer-typed array.
    cv_prob_pred = np.zeros(np.shape(rY), dtype=float)
    # Unlike the real-label analysis, the un-permuted `group` is used here.
    for tridx, tsidx in cvg.split(X, rY, group):
        clf.fit(X[tridx], rY[tridx])
        cv_prob_pred[tsidx] = clf.predict_proba(X[tsidx])[:, 1]

    # Accuracy against the permuted labels gives one sample of the null distribution.
    sc, sc_vec = bf.force_binary_accuracy(rY, cv_prob_pred)
    rand_accuracy[ird] = sc

In [6]:
# save these variables for plotting
mvals = np.mean(all_accuracy) # mean accuracy over the repeated cross-validations
evals = np.std(all_accuracy)/np.sqrt(gperm_num) # std (numpy default ddof) divided by sqrt of the number of repetitions
chance_vals = np.percentile(rand_accuracy, 97.5, axis=0) # 97.5th percentile of the permuted-label accuracies

# Create the output directory when it is missing so the dump does not fail on
# a fresh checkout; the guard skips this when svfile has no directory part.
out_dir = os.path.dirname(svfile)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)

# The context manager guarantees the file is flushed and closed after writing.
with open(svfile, 'wb') as out_fh:
    pickle.dump({'chance_vals':chance_vals,
                 'mvals':mvals,
                 'evals':evals},
                out_fh)

In [7]:
## Gather the results from all phases and trial-blocks.
## Run this cell only after the results for every phase and trial-block have been computed,
## i.e., when all files f'./sample_results/discovery_{cphs}_{mask}_{tblk}block_accuracy.pkl' are ready.

# mask = 'fearnet' # feature to be used, 'fearnet' or 'wholebrain'
# all_chance_vals = []
# all_mvals = []
# all_evals = []
# for cphs in ['cond', 'ext', 'recall']:
#     for tblk in [1,2,3,4]:
#         file = f'./sample_results/discovery_{cphs}_{mask}_{tblk}block_accuracy.pkl'
#         D = pickle.load(open(file, 'rb'))
#         chance_vals = D['chance_vals']
#         mvals = D['mvals']
#         evals = D['evals']
#         all_mvals.append(mvals)
#         all_evals.append(evals)
#         all_chance_vals.append(chance_vals)
# all_mvals = np.array(all_mvals)
# all_evals = np.array(all_evals)
# all_chance_vals = np.array(all_chance_vals)

# if mask=='wholebrain':
#     savefile = './results/whole_brain_cross_validation_accuracy.pkl'
# elif mask=='fearnet':
#     savefile = './results/threat_circuit_cross_validation_accuracy.pkl'
# pickle.dump({'all_chance_vals':all_chance_vals,
#              'all_mvals':all_mvals,
#              'all_evals':all_evals},
#              open(savefile, 'wb'))