In [None]:
import numpy as np
import pandas as pd
import os
import base_functions as bf
import pickle
from sklearn import svm
from sklearn.model_selection import RepeatedStratifiedKFold, GridSearchCV
from sklearn import metrics

from nilearn import connectome, plotting
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Load the pickled connectivity dataset and unpack the entries used below.
# NOTE(review): pickle.load can execute arbitrary code while unpickling;
# only point `dtfile` at a file from a trusted source.
dtfile = './data/dataset1_connectivity.pkl'
with open(dtfile, 'rb') as fh:  # the context manager closes the handle even if unpickling fails
    D = pickle.load(fh)

all_session_subjects = D['all_session_subjects']  # indexed by session number below
all_session_conn_vec = D['all_session_conn_vec']  # indexed by session number below
all_session = D['all_session']
info_DF = D['info_DF']  # passed to bf.get_subject_info to look up subject labels

In [None]:
# Session 0 is the first session, i.e. the morning after the sleep manipulation.
care_sess = 0
session_subjects = all_session_subjects[care_sess]

# Label coding: 0 -> normal sleep; 1 -> partial sleep deprivation; 2 -> sleep deprivation
dep_labels = bf.get_subject_info(info_DF, session_subjects, ['deprive_labels'])

# Boolean mask keeping only the normal-sleep (0) and sleep-deprivation (2) subjects.
idx = (dep_labels == 0) | (dep_labels == 2)

# Fisher r-to-z transform of the selected connectivity vectors, with NaNs zeroed.
# NOTE(review): arctanh(+/-1) yields +/-inf, which is NOT caught by the NaN
# clean-up below - confirm the vectors contain no entries with |r| == 1.
X = np.arctanh(all_session_conn_vec[care_sess][idx])
X[np.isnan(X)] = 0

# Binarise the labels. Only the values 0 and 2 survive the mask, so
# "differs from the minimum" maps 0 -> 0 and 2 -> 1 (all zeros if only one
# group is present); the original label dtype is kept.
selected_labels = dep_labels[idx]
Y = (selected_labels != selected_labels.min()).astype(selected_labels.dtype)

# Subject identifiers of the retained rows, in the same order as X and Y.
subjects = [session_subjects[i] for i in np.flatnonzero(idx)]

In [None]:
# Repeated, stratified 10-fold cross-validation of a linear SVM.
# For every repetition the out-of-fold decision values are collected for all
# samples and summarised as a single ROC-AUC; everything is pickled at the end.
rep_num = 10  # repeat the cross-validation this many times

# Grid of candidate C values for the SVM; it is identical for every fold,
# so it is built once instead of inside the inner loop.
parameters = {'C': np.linspace(0.00001, 10000, 20)}

all_scores = np.zeros(rep_num)                      # one ROC-AUC per repetition
all_predict_prob = np.zeros((X.shape[0], rep_num))  # out-of-fold decision values, one column per repetition
all_test_index = []                                 # per repetition: list of test-index arrays, one per fold
for irep in range(rep_num):
    print(irep)
    # Seeding the splitter with the repetition index gives a different but
    # reproducible partition for each repetition.
    cvg = RepeatedStratifiedKFold(n_splits=10, n_repeats=1, random_state=irep)

    # Explicit float buffer: np.zeros_like(Y) would inherit the dtype of the
    # labels, and an integer dtype would truncate the real-valued
    # decision_function output before the AUC is computed.
    cv_prob_pred = np.zeros(X.shape[0], dtype=float)
    test_index = []
    # cross validation
    for tridx, tsidx in cvg.split(X, Y):
        trX, trY = X[tridx], Y[tridx]
        tsX = X[tsidx]
        # C is tuned by an inner grid search on the training fold only.
        clf = GridSearchCV(svm.SVC(kernel='linear', probability=True), parameters, n_jobs=-1)
        clf.fit(trX, trY)
        # Despite the variable name these are decision_function values, not probabilities.
        cv_prob_pred[tsidx] = clf.decision_function(tsX)
        test_index.append(tsidx)
    sc = metrics.roc_auc_score(Y, cv_prob_pred)

    # The result containers are pre-allocated numpy arrays, which have no
    # .append method, so they are filled by position instead.
    all_predict_prob[:, irep] = cv_prob_pred
    all_test_index.append(test_index)
    all_scores[irep] = sc

# Make sure the output directory exists, then write the results; the context
# manager guarantees the file is flushed and closed.
os.makedirs('./results', exist_ok=True)
with open('./results/accuracy_10-10CV_D1Morning.pkl', 'wb') as fh:
    pickle.dump({'all_predict_prob': all_predict_prob,
                 'all_scores': all_scores,
                 'all_test_index': all_test_index,
                 'subjects': subjects,
                 'Y': Y},
                fh)

In [None]:
## Train model for generalization
# Fit the grid-searched linear SVM on all of X and Y (no held-out fold) and
# save the refit best estimator together with its weights.
parameters = {'C': np.linspace(0.00001, 10000, 20)}
clf = GridSearchCV(svm.SVC(kernel='linear', probability=True), parameters, n_jobs=-1)
clf.fit(X, Y)

# Coefficients of the best estimator found by the grid search.
coef_vals = clf.best_estimator_.coef_
# NOTE(review): bf.weight_transform is defined outside this notebook; presumably
# it converts the classifier weights into interpretable patterns - confirm.
patterns = bf.weight_transform(X, coef_vals)

# The context manager guarantees the pickle file is flushed and closed.
with open('./Discovery_data_trained_model.pkl', 'wb') as fh:
    pickle.dump({'clf': clf.best_estimator_,
                 'patterns': patterns,
                 'coef_vals': coef_vals,
                 'subjects': subjects},
                fh)