In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
import scipy.io as io

from sklearn import datasets
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import GroupKFold, StratifiedKFold

import sys
from IPython.display import clear_output

# Ask every numerical backend (OpenMP, OpenBLAS, MKL, Accelerate/vecLib, numexpr)
# for a single thread; presumably so the parallel grid-search jobs below do not
# oversubscribe the CPU.
# NOTE(review): this runs after numpy has been imported, so libraries already loaded
# in this process may not pick the values up -- confirm this is only meant for workers.
os.environ.update({
    thread_var: '1'
    for thread_var in (
        'OMP_NUM_THREADS',
        'OPENBLAS_NUM_THREADS',
        'MKL_NUM_THREADS',
        'VECLIB_MAXIMUM_THREADS',
        'NUMEXPR_NUM_THREADS',
    )
})

def verbose(text):
    """Show ``text`` as the only output of the current cell.

    The previous output is cleared lazily (``wait=True``: it is replaced when the
    new output arrives), then ``text`` is printed and stdout is flushed so the
    message appears immediately.
    """
    clear_output(wait=True)
    print(text, flush=True)
    
def errorfill(x, y, yerr, color=None, label=None, alpha_fill=0.3, ax=None):
    """
    Plot ``y`` against ``x`` as a line surrounded by a shaded error band.

    Parameters
    ----------
    x : array-like
        Abscissa values.
    y : array-like
        Ordinate values (converted to an array, so plain lists are accepted).
    yerr : scalar, array-like of length ``len(y)``, or pair ``(ymin, ymax)``
        A scalar or a per-point array is a symmetric error (band is ``y -/+ yerr``).
        A length-2 sequence gives the lower and upper band edges directly. The
        symmetric reading wins when ``len(y) == 2``.
    color : matplotlib color, optional
        Color of both line and band. When omitted, the line takes the next color
        of the axes' property cycle and the band reuses it.
    label : str, optional
        Legend label. Attached to the line only, so a legend shows one entry.
    alpha_fill : float
        Opacity of the shaded band.
    ax : matplotlib Axes, optional
        Axes to draw on; defaults to ``plt.gca()``.

    Returns
    -------
    The axes that were drawn on.

    Raises
    ------
    ValueError
        If ``yerr`` matches none of the supported forms. This is raised before
        anything is drawn.
    """
    ax = ax if ax is not None else plt.gca()
    y = np.asanyarray(y)
    if np.isscalar(yerr) or len(yerr) == len(y):
        yerr = np.asanyarray(yerr)
        ymin, ymax = y - yerr, y + yerr
    elif len(yerr) == 2:
        ymin, ymax = yerr
    else:
        raise ValueError(
            "yerr must be a scalar, have the same length as y, "
            "or be a (ymin, ymax) pair; got length %d for len(y) == %d"
            % (len(yerr), len(y))
        )

    line_kwargs = {'label': label}
    if color is not None:
        line_kwargs['color'] = color
    lines = ax.plot(x, y, **line_kwargs)
    if color is None:
        # Reuse whatever color the property cycle handed to the line (public API,
        # instead of poking at the private ``ax._get_lines`` helper).
        color = lines[0].get_color()
    ax.fill_between(x, ymax, ymin, color=color, alpha=alpha_fill)
    return ax

In [None]:
# Path of the .npz archive holding the ERP recordings; the commented lines are
# alternative archives that can be swapped in.
filename = '../shareddata/ERPs.npz'
# filename = '../shareddata/ERPs-b50.npz'
# filename = '../shareddata/ERPs-b100.npz'
# SECURITY: allow_pickle=True lets np.load unpickle object arrays, and unpickling
# can execute arbitrary code -- only load archives from a trusted source.
# The entries read below are indexed as arrays of arrays, hence the need for pickle.
data = np.load(filename, allow_pickle=True)

In [None]:
# Per-group recordings, indexed as [participants, catch/target, C1/C2/C3/C4]
X_asd_, X_typ_ = data['ERPs_ASD'], data['ERPs_TYP']

In [None]:
# Number of participants in each group (length of the leading axis)
n_asd, n_typ = len(X_asd_), len(X_typ_)
# One recording is laid out as (electrodes, timesteps, trials)
n_elec, n_t, _ = np.shape(X_asd_[0, 0, 0])



In [None]:
%%timeit
# average each unique condition over trials, put them all together
# mean is over the third index (trials), then stack them together, using third index squeezed to 1,
# so it becomes the total number of subjects
# the + concatenates the lists
X_all = np.zeros((2,4,n_asd+n_typ,n_t,n_elec)) # catch/target | C1/C2/C3/C4
for k in range(4):
    X_all[0,k] = np.concatenate([X_asd_[i,0,k].mean(2, keepdims=True)
                                 for i in range(n_asd)]+
                                [X_typ_[i,0,k].mean(2, keepdims=True)
                                 for i in range(n_typ)],2).T
    X_all[1,k] = np.concatenate([X_asd_[i,1,k].mean(2, keepdims=True)
                                 for i in range(n_asd)]+
                                [X_typ_[i,1,k].mean(2, keepdims=True)
                                 for i in range(n_typ)],2).T
# Find the subjects that do not have any NaN    
idx = np.isnan(X_all.mean((0,1,3,4)))==False
X_all = X_all[:,:,idx]


In [None]:
%%timeit
# exercise: do my alternative X_all
# means separately, concatenate, remove NaN subjects

def temp_mean(X_):
    """Average every recording stored in the object array ``X_`` over its trials.

    ``X_`` is an array whose elements are themselves arrays with the trials on
    their last axis. Returns a dense float array of shape
    ``X_.shape + element.shape[:-1]`` (the element shape is read from the first
    element), where each entry is the corresponding element averaged over that
    last axis.
    """
    X_ret = np.zeros(X_.shape + X_.flat[0].shape[:-1])  # all dims except last
    for ijk in np.ndindex(X_.shape):
        # axis=-1 is the element's own trial axis; the container's rank (X_.ndim)
        # is unrelated to it and only matched by coincidence when both are 3.
        X_ret[ijk] = np.mean(X_[ijk], axis=-1)
    return X_ret
    
# stack the ASD averages on top of the TYP averages along the subject axis
temp = np.concatenate([temp_mean(X_asd_), temp_mean(X_typ_)], axis=0)
# a subject is kept only if none of its values (any axis but the first) is NaN
non_subject_axes = tuple(range(1, temp.ndim))
idx_subj_notnan = ~np.isnan(temp).any(axis=non_subject_axes)
# keep only good subjects
X_all_alt = temp[idx_subj_notnan]

In [None]:
%%timeit
def temp_mean_super(X_in, X_):
    """Write the trial average of every recording of ``X_`` into ``X_in`` in place.

    ``X_`` is an object array whose elements have the trials on their last axis.
    ``X_in`` is a preallocated array (or view) of shape
    ``X_.shape + element.shape[:-1]``; ``X_in[ijk]`` receives the average of
    ``X_[ijk]`` through ``np.mean(..., out=...)``. Returns ``None``.
    """
    for ijk in np.ndindex(X_.shape):
        # axis=-1 is the element's own trial axis (not tied to the container rank)
        np.mean(X_[ijk], axis=-1, out=X_in[ijk])
    return
nasd = len(X_asd_)
nsubj = nasd + len(X_typ_)

# output layout: (subject,) + per-subject container dims + recording dims without trials
shape_out = (nsubj,) + X_asd_.shape[1:] + X_asd_.flat[0].shape[:-1]
mydtype = X_asd_.flat[0].dtype
temp = np.empty(shape_out, dtype=mydtype)

# fill the ASD rows first, then the TYP rows, writing straight into views of ``temp``
temp_mean_super(temp[:nasd], X_asd_)
temp_mean_super(temp[nasd:], X_typ_)
# a subject is kept only if none of its values (any axis but the first) is NaN
non_subject_axes = tuple(range(1, temp.ndim))
idx_subj_notnan = ~np.isnan(temp).any(axis=non_subject_axes)
# keep only good subjects
X_all_alt2 = temp[idx_subj_notnan]


In [None]:

def temp_mean_extreme(X_ret, X_):
    """Write trial averages into ``X_ret`` with the recording axes leading.

    ``X_`` is an object array whose elements have the trials on their last axis.
    ``X_ret`` is a preallocated array (or view) of shape
    ``element.shape[:-1] + X_.shape``; for every index ``ijk`` of ``X_`` the slab
    ``X_ret[..., *ijk]`` receives the average of ``X_[ijk]``. Returns ``None``.
    """
    for ijk in np.ndindex(X_.shape):
        # Ellipsis spans however many leading recording axes there are, and
        # axis=-1 is the element's own trial axis (not tied to the container rank).
        X_ret[(Ellipsis,) + ijk] = np.mean(X_[ijk], axis=-1)
    return
nasd = X_asd_.shape[0]
nsubj = nasd + X_typ_.shape[0]

# Working layout: (electrodes, timesteps, subject, catch/target, condition),
# i.e. the recording axes (minus trials) first, then the subject axis.
shape_out = (X_asd_.flat[0].shape[0:-1] + (nsubj,) + X_asd_.shape[1:])
mydtype = X_asd_.flat[0].dtype
temp = np.empty(shape_out, dtype=mydtype)

temp_mean_extreme(temp[:, :, 0:nasd, :, :], X_asd_)
temp_mean_extreme(temp[:, :, nasd:, :, :], X_typ_)
# find subjects that do not have nan values anywhere: in this layout the subject
# axis is axis 2, so the NaN test must reduce over every axis EXCEPT that one
# (reducing over axes 1.. would produce a per-electrode mask instead).
subj_axis = 2
other_axes = tuple(a for a in range(temp.ndim) if a != subj_axis)
idx_subj_notnan = np.logical_not(np.any(np.isnan(temp), axis=other_axes))
# keep only good subjects, then reorder to (subject, catch/target, condition,
# electrodes, timesteps), which is how the comparison cell indexes X_all_alt2
X_all_alt2 = temp[:, :, idx_subj_notnan].transpose(2, 3, 4, 0, 1)


In [None]:
# test that they are the same
# X_all is (catch/target, condition, participant, time, electrode) while
# X_all_alt2 is indexed (participant, catch/target, condition, electrode, time):
# bring the latter into X_all's axis order and compare everything in one
# vectorised operation instead of one Python-level comparison per element.
X_all_alt2_aligned = X_all_alt2.transpose(1, 2, 0, 4, 3)
assert X_all_alt2_aligned.shape == X_all.shape, (
    f"shape mismatch: {X_all_alt2_aligned.shape} vs {X_all.shape}")
sames = X_all == X_all_alt2_aligned

print(f"Test if they are exactly the same -> {sames.all()}")



In [None]:
# training labels: 0 for every ASD participant followed by 1 for every TYP participant,
# restricted to the participants kept by the NaN filter
y = np.repeat([0.0, 1.0], [n_asd, n_typ])[idx]
# target/catch , conditions, participants, timebins, electrodes
n_tc, n_conds, n_participants, n_t, n_elec = np.shape(X_all)

In [None]:
# number of raw values (time bins x electrodes) for one subject in one condition
n_elec * n_t

In [None]:
# test_C

In [None]:
# Model-selection setup: a PCA -> classifier pipeline and the grid searched over it.
pca = PCA()
lr = 1  # truthy: logistic regression, falsy: linear discriminant analysis

# candidate numbers of PCA components
test_pca = np.arange(1, 38)
# candidate values of C (inverse regularisation strength of the logistic regression)
test_C = np.logspace(-5, 0, 6)
n_C, n_pca = test_C.shape[0], test_pca.shape[0]

# the number of components is always searched; C only exists for logistic regression
param_grid = {'pca__n_components': test_pca}
if lr:
    # NOTE(review): tol=1.0 is far looser than scikit-learn's default (1e-4) -- confirm intended
    clf = LogisticRegression(max_iter=5000, tol=1.0)
    param_grid['clf__C'] = test_C
else:
    clf = LinearDiscriminantAnalysis(tol=1e-3)

# clf = QuadraticDiscriminantAnalysis(tol=1e-3)
pipe = Pipeline(steps=[('pca', pca), ('clf', clf)])
# search = GridSearchCV(pipe, param_grid, cv=5, n_jobs=1, verbose=10)

In [None]:
# One fitted grid search per (catch/target, condition) pair
search_all = np.empty((n_tc, n_conds), dtype=object)
for i, j in np.ndindex(n_tc, n_conds):
    # features: the squared signal, flattened to one row per participant
    X = np.square(X_all[i, j].reshape(n_participants, n_t * n_elec))

    # alternative: raw and squared signal side by side
    # (using also squared signal improve performances)
    #X = np.concatenate([X_all[i,j].reshape(n_participants,n_t*n_elec),
    #                   X_all[i,j].reshape(n_participants,n_t*n_elec)**2],
    #                   axis=1)

    # z-score every feature across participants (centre first, then scale by the
    # standard deviation of the centred data)
    X = X - X.mean(0, keepdims=True)
    X = X / X.std(0, keepdims=True)
    search_all[i, j] = GridSearchCV(pipe, param_grid,
                                    cv=5, n_jobs=15, verbose=0).fit(X, y)
    verbose('%i,%i' % (i, j))

In [None]:
# Mean cross-validated score (+/- standard error over folds) against the number of
# PCA components, one panel per (catch/target, condition) pair.
fig, ax = plt.subplots(n_tc, n_conds, figsize=(12, 6))
for i in range(n_tc):
    for j in range(n_conds):
        gs = search_all[i, j]
        print("Best parameter (CV score=%0.3f):" % gs.best_score_)
        print(gs.best_params_)

        mean_score = gs.cv_results_['mean_test_score']
        # standard error of the mean: std across folds / sqrt(number of folds),
        # with the fold count read from the fitted search rather than hard-coded
        sem_score = gs.cv_results_['std_test_score'] / np.sqrt(gs.n_splits_)
        if lr:
            # GridSearchCV walks the grid with the parameter names sorted, so
            # 'clf__C' varies slowest: one row per C, one column per n_components
            grid_shape = (test_C.shape[0], test_pca.shape[0])
            mean_grid = mean_score.reshape(grid_shape)
            sem_grid = sem_score.reshape(grid_shape)
            for k in range(test_C.shape[0]):
                errorfill(test_pca, mean_grid[k], sem_grid[k], ax=ax[i, j])
        else:
            errorfill(test_pca, mean_score, sem_score, ax=ax[i, j])
        ax[i, j].set_title('%s, C%i' % (('catch', 'target')[i], j + 1))
        ax[i, j].set_ylim(0.4, 0.95)
        


## Things to try
- Concatenate the conditions C1/C2/C3/C4
- Use non-linear methods like kernel methods (but we are already overfitting)
- Craft our own Quadratic Discriminant Analysis with diagonal cov matrices
- Use an L1 penalty with logistic regression (well suited to finding a relevant
low-dimensional subspace)
- Use brute-force PCA component selection (trying all intervals [i,j])
- Do channel selection and time-window selection

In [None]:
# pulling all conditions together (and time windows selection ??)
# Xct[0] / Xct[1]: for the first / second entry of the catch/target axis, the
# flattened (time x electrode) features of every condition laid side by side,
# one row per participant
Xct = np.zeros(2, dtype=object)
for tc in range(2):
    per_condition = [X_all[tc, j].reshape(n_participants, n_t * n_elec)
                     for j in range(n_conds)]
    Xct[tc] = np.concatenate(per_condition, 1)

In [None]:
# Setup for the L1-penalised logistic regression on the pooled conditions.
# NOTE: this rebinds pca / test_pca / test_C / clf / param_grid / pipe defined earlier.
pca = PCA()
# single number of components; other choices tried: np.arange(5,38,5), np.array([15,38])
test_pca = np.array([32])
# candidate values of C, log-spaced between 1e-3 and 1
test_C = np.logspace(-3, 0, 41)
n_C, n_pca = test_C.shape[0], test_pca.shape[0]

clf = LogisticRegression(penalty='l1', solver='liblinear', max_iter=5000, tol=1.0)
param_grid = dict(pca__n_components=test_pca, clf__C=test_C)

# clf = QuadraticDiscriminantAnalysis(tol=1e-3)
pipe = Pipeline(steps=[('pca', pca), ('clf', clf)])
# search = GridSearchCV(pipe, param_grid, cv=5, n_jobs=1, verbose=10)

In [None]:
# One fitted grid search per entry of the catch/target axis
search_l1_all = np.zeros(n_tc, dtype=object)
for i in range(n_tc):
    #X = Xct[i]**2
    # features: the pooled signal followed by its square
    # (using also squared signal improve performances)
    X = np.hstack((Xct[i], np.square(Xct[i])))

    # z-score every feature across participants (centre, then scale by the
    # standard deviation of the centred data)
    X = X - X.mean(0, keepdims=True)
    X = X / X.std(0, keepdims=True)
    search_l1_all[i] = GridSearchCV(pipe, param_grid,
                                    cv=5, n_jobs=15, verbose=0).fit(X, y)
    verbose('%i' % (i))

In [None]:
# Mean cross-validated score (+/- standard error over folds) against C,
# one panel per entry of the catch/target axis and one curve per number of components.
fig, ax = plt.subplots(1, n_tc, figsize=(12, 3))
for i in range(n_tc):
    gs = search_l1_all[i]
    print("Best parameter (CV score=%0.3f):" % gs.best_score_)
    print(gs.best_params_)

    # GridSearchCV walks the grid with the parameter names sorted, so 'clf__C'
    # varies slowest: one row per C, one column per n_components
    grid_shape = (test_C.shape[0], test_pca.shape[0])
    mean_grid = gs.cv_results_['mean_test_score'].reshape(grid_shape)
    # standard error of the mean, with the fold count read from the fitted search
    # rather than hard-coded
    sem_grid = (gs.cv_results_['std_test_score'].reshape(grid_shape)
                / np.sqrt(gs.n_splits_))
    for k in range(test_pca.shape[0]):
        errorfill(test_C, mean_grid[:, k], sem_grid[:, k], ax=ax[i])

    ax[i].set_xscale('log')
    ax[i].set_ylim(0.45, 0.85)
    ax[i].set_xlabel('C')
    ax[i].set_title(('catch', 'target')[i])

In [None]:
# For each fitted search: count the coefficients of the best classifier whose
# magnitude exceeds 0.5e-2, and plot all of its coefficients.
for i in range(n_tc):
    coef = search_l1_all[i].best_estimator_['clf'].coef_
    n_above = np.sum(np.abs(coef) > 0.5e-2)
    print(n_above)
    plt.plot(coef.T)