In [1]:
%matplotlib inline

import warnings
warnings.filterwarnings('ignore')
import pickle, os, glob
import numpy as np
import sys, copy, time

import mne
mne.set_log_level('WARNING')
from EEG.info import info_exp
from EEG.converter import converter_mne
from EEG.utils import load, save, data_prep, plot_confusion_matrix, plot_mult_conf_matrices,psd

import scipy.signal as spsig

from mne.decoding import CSP,UnsupervisedSpatialFilter, Vectorizer
from mne.time_frequency import psd_multitaper
from pyriemann.classification import MDM, TSclassifier
from pyriemann.estimation import covariances, XdawnCovariances


from sklearn import metrics
from sklearn.pipeline import Pipeline

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.svm import SVC
from sklearn.neural_network import MLPClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

In [2]:
def psd(raw,picks,tmax,fmax):
    """Multitaper power spectral density summary, in dB.

    Parameters
    ----------
    raw : object accepted by ``mne.time_frequency.psd_multitaper``
        Recording whose spectrum is estimated.
    picks : array-like of int
        Channel indices forwarded to ``psd_multitaper``.
    tmax : float
        Upper time limit forwarded to ``psd_multitaper``.
    fmax : float
        Upper frequency limit forwarded to ``psd_multitaper``.

    Returns
    -------
    psds_mean, psds_std, freqs
        Mean and standard deviation of ``10 * log10(psd)`` taken along the
        first axis of the array returned by ``psd_multitaper`` (presumably
        the channel axis for a Raw input -- TODO confirm), and the
        frequency vector returned alongside it.
    """
    # Forward the caller's limits; they used to be ignored in favour of
    # hard-coded 4.499 / 50. values.
    psds, freqs = psd_multitaper(raw, low_bias=True, tmax=tmax,
                              fmax=fmax, proj=True, picks=picks,
                              n_jobs=1)
    # Convert power to decibels before summarising.
    psds = 10 * np.log10(psds)
    psds_mean = psds.mean(0)
    psds_std = psds.std(0)

    return psds_mean, psds_std, freqs

In [8]:
# ---------------------------------------------------------------------------
# Benchmark: for every subject / recording file / filter band / decoder /
# classifier combination, fit on the training epochs, score on each test
# set, and collect the statistics in `statdata` (nested dict) and
# `stattable` (flat list of rows that is saved to disk at the end).
# ---------------------------------------------------------------------------
root = "C:\\eeg\\01exp\\"
labtests = ["20161129_DBS_001","20161209_KPS_001","20161210_GSH_001"]
names = ["exp_data_aligned_start_0.cls","exp_data_aligned_start_150.cls"]
# Band limits handed to raw.filter() and data_prep().
filters = [{'min':0,'max':20},{'min':6,'max':16},{'min':10,'max':16}]

# A decoder is either a CSP spatial filter or the marker string 'Riemann',
# which selects the covariance + tangent-space classifier path below.
decoders = [
        CSP(n_components=2, reg='ledoit_wolf', log=True),
        'Riemann'
      ]

clfs = [
    LinearDiscriminantAnalysis(),
    GaussianNB(),
    KNeighborsClassifier(),
    DecisionTreeClassifier(),
    SVC(),
    MLPClassifier(solver='lbfgs', alpha=1e-5, hidden_layer_sizes=(4, 10), random_state=1),
    LogisticRegression()
]

# Epoch duration parameter (passed to psd() as tmax).
tepochs = 4.499

statdata = {}
stattable = []

# Loop over all subjects.
for lab in labtests:
    statdata[lab] = {}

    # Loop over all transition-time cutoffs (one recording file each).
    for name in names:
        statdata[lab][name] = {}
        test = load(root + lab,name)
        data  = converter_mne(test)

        # Loop over all filter bands.
        for pfilter in filters:
            fmin, fmax = pfilter['min'], pfilter['max']
            freqlabel = str(fmin) +"-" +str(fmax)
            statdata[lab][name][freqlabel] = {}
            epochs = data.train_epochs()
            raw = data.train_raw()
            # Filter the raw data (used only for the PSD summary below).
            raw.filter(fmin,fmax,
                       phase='zero',filter_length='auto',
                       fir_window='hamming',l_trans_bandwidth='auto',
                       h_trans_bandwidth='auto')
            picks = mne.pick_types(info=data.mne_info(), meg=False, eeg=True, misc=False)
            # Compute the power spectral density summary for later plotting.
            psds_mean, psds_std, freqs = psd(raw,picks,tepochs,50.)
            statdata[lab][name][freqlabel]['psd'] = {
                'psds_mean':psds_mean,'psds_std':psds_std,'freqs':freqs}

            ### decoder
            for dec in decoders:
                statdata[lab][name][freqlabel][dec] = {}
                csp = dec
                is_riemann = (dec == 'Riemann')
                # Human-readable decoder label for the result table.
                # Computed for every decoder so that the 'Riemann' rows are
                # no longer mislabelled with the previous decoder's name.
                decs = dec if isinstance(dec, str) else dec.__class__.__name__

                for clf in clfs:
                    clfname = clf.__class__.__name__
                    statdata[lab][name][freqlabel][dec][clfname] = {}

                    # NOTE(review): the test sets are re-read for every
                    # classifier; kept that way because it is not visible
                    # here whether data_prep() modifies its input in place.
                    test_list = []
                    for i in range(0,data.num_tests):
                        epochs_test = data.test_epochs(i)
                        test_list.append({'Y_test':epochs_test.events[:,-1],
                                          'X_test':epochs_test.get_data()})

                    # 1000 is data_prep()'s second argument -- presumably the
                    # sampling rate in Hz; TODO confirm.
                    start_time = time.time()
                    if is_riemann:
                        # Training set is cut to the epoch count of the
                        # first test set (as in the original analysis).
                        n_ref = test_list[0]['Y_test'].shape[0]
                        Y_train = epochs.events[:,-1][:n_ref]
                        X_train = epochs.get_data()[:n_ref]
                        X_train = data_prep(X_train, 1000, fmin, fmax)
                        # Wrap the base classifier; use a separate name so
                        # the loop variable `clf` is not rebound.
                        model = TSclassifier(clf=clf)
                        # compute covariance matrices
                        data_cov = covariances(X_train, estimator='oas')
                        model.fit(data_cov , Y_train)
                    else:
                        Y_train = epochs.events[:,-1]
                        X_train = epochs.get_data()
                        X_train = data_prep(X_train, 1000, fmin, fmax)
                        X_train = csp.fit_transform(X_train, Y_train)
                        model = clf
                        model.fit(X_train, Y_train)

                    tfitv = time.time() - start_time

                    scores = []
                    tscores = []
                    classification_report = []
                    confusion_matrix = []

                    for test_case in test_list:
                        if is_riemann:
                            X_test = data_prep(test_case['X_test'][:n_ref], 1000, fmin, fmax)
                            X_test = covariances(X_test, estimator='oas')
                            # Labels must be truncated exactly like the
                            # samples, otherwise lengths can disagree.
                            Y_test = test_case['Y_test'][:n_ref]
                        else:
                            X_test = data_prep(test_case['X_test'], 1000, fmin, fmax)
                            X_test = csp.transform(X_test)
                            Y_test = test_case['Y_test']

                        # Only the scoring call itself is timed.
                        start_time = time.time()
                        scores.append(model.score(X_test, Y_test))
                        tscores.append(time.time() - start_time)

                        predicted = model.predict(X_test)
                        classification_report.append(metrics.classification_report(Y_test, predicted))
                        confusion_matrix.append(metrics.confusion_matrix(Y_test, predicted))

                    #plot_mult_conf_matrices(confusion_matrix,[u'Rest',u'left',u'right'])

                    statdata[lab][name][freqlabel][dec][clfname].update({
                        'max': np.max(scores),
                        'mean': np.mean(scores),
                        'min': np.min(scores),
                        'disp': np.std(scores),
                        'conf_matrix': confusion_matrix
                    })

                    # Row layout: subject, file, band, decoder, classifier,
                    # max/mean/min/std of score, fit time, mean score time,
                    # confusion matrices, PSD summary.
                    stattable.append([
                        lab,name,freqlabel,decs,clfname,
                        np.max(scores), np.mean(scores),
                        np.min(scores), np.std(scores),
                        tfitv, np.mean(tscores),
                        confusion_matrix,
                        {'psds_mean':psds_mean,'psds_std':psds_std,'freqs':freqs}
                    ])


save(root,'result.dat',stattable)
# Parenthesised form prints the same text on Python 2 and Python 3.
print("success")

success
