In [1]:
import mne
import os
mne.set_log_level('ERROR')

import numpy as np
from copy import deepcopy
import time
from sklearn.pipeline import make_pipeline
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from pyriemann.tangentspace import TangentSpace
from pyriemann.estimation import XdawnCovariances
from sklearn.metrics import balanced_accuracy_score,f1_score,recall_score
from sklearn.model_selection import train_test_split
from imblearn.under_sampling import RandomUnderSampler

# Useful function

In [2]:
def balance(X,Y,domains):
    X_new = []
    Y_new = []
    domains_new = []
    if domains is not None:
        for d in np.unique(domains):
            ind_domain = np.where(domains==d)
            rus = RandomUnderSampler()
            counter=np.array(range(0,len(Y[ind_domain]))).reshape(-1,1)
            index,_ = rus.fit_resample(counter,Y[ind_domain])
            index = np.sort(index,axis=0)
            X_new.append(np.squeeze(X[ind_domain][index,:,:], axis=1))
            Y_new.append(np.squeeze(Y[ind_domain][index]))
            domains_new.append(np.squeeze(domains[ind_domain][index]))
        return np.concatenate(X_new),np.concatenate(Y_new),np.concatenate(domains_new)
    else:
        rus = RandomUnderSampler()
        counter=np.array(range(0,len(Y))).reshape(-1,1)
        index,_ = rus.fit_resample(counter,Y)
        index = np.sort(index,axis=0)
        X = np.squeeze(X[index,:,:], axis=1)
        Y = np.squeeze(Y[index])
        return X,Y,None

In [3]:
def make_preds_accumul_aggresive(y_pred, codes, min_len=30, sfreq=500, consecutive=30, window_size=0.25):
    length = int((2.2-window_size)*sfreq)
    y_pred = np.array(y_pred)
    rez_acc = []

    code_buffer = []
    labels_pred = []
    code_pos = 0
    y_tmp = [] 
    mean_long = []

    for trial in range(int(len(y_pred)/length)):   
        # Retrieve a trial
        tmp_code = y_pred[trial*length:(trial+1)*length]
        code_pos = 0

        # Do an average over the prdata, codes, labels, sfreq
        code_buffer = []
        for idx in range(len(tmp_code)):
            y_tmp.append(tmp_code[idx])
            if (idx+1)/sfreq >= (code_pos+1)/60:
                code_pred = np.mean(y_tmp) 
                code_pred = int(np.rint(code_pred))
                code_buffer.append(code_pred) 
                y_tmp = []
                code_pos += 1
        # Find the code that correlate the most
        corr = -2
        pred_lab = -1
        out = 0
        for long in np.arange(min_len, len(code_buffer) -1 , step=1):
            dtw_values = []
            for key, values in codes.items():
                dtw_values.append(np.corrcoef(code_buffer[:long], values[:long])[0,1])
            dtw_values = np.array(dtw_values)
            max_dtw = list(codes.keys())[np.argmax(dtw_values)] 
            if (max_dtw == pred_lab):
                out += 1
                corr = np.max(dtw_values)
            else:
                pred_lab = max_dtw
                out = 0
            if out == consecutive:
                mean_long.append((long)/60)
                break
        labels_pred.append(pred_lab)
    labels_pred = np.array(labels_pred)
    return labels_pred, code_buffer, mean_long

# Get the data

First you need to download the data and put it in the folder you want

In [15]:
# get the participant you want to look at
participants = ['P1','P2','P3','P4','P5','P6','P7','P8','P9','P10',
                'P11','P12','P13','P14','P15','P16','P17','P18','P19','P20',
                'P21','P22','P23','P24']


path = 'C:\\Users\\s.velut\\Documents\\These\\Protheus_PHD\\Data\\Dry_Ricker'
nb_subject = len(participants)

# General variables and hyperparameters
n_class=5
fmin = 1
fmax = 45
fps = 60
window_size = 0.35
sfreq = 500

# get a list of the data in mne format
raw_eeglab = [mne.io.read_raw_eeglab(os.path.join(path, '_'.join([participants[i], 'dryburst100.set'])), preload=True, verbose=False).resample(sfreq=sfreq)
               for i in range(len(participants))]

  raw_eeglab = [mne.io.read_raw_eeglab(os.path.join(path, '_'.join([participants[i], 'dryburst100.set'])), preload=True, verbose=False).resample(sfreq=sfreq)
  raw_eeglab = [mne.io.read_raw_eeglab(os.path.join(path, '_'.join([participants[i], 'dryburst100.set'])), preload=True, verbose=False).resample(sfreq=sfreq)
  raw_eeglab = [mne.io.read_raw_eeglab(os.path.join(path, '_'.join([participants[i], 'dryburst100.set'])), preload=True, verbose=False).resample(sfreq=sfreq)
  raw_eeglab = [mne.io.read_raw_eeglab(os.path.join(path, '_'.join([participants[i], 'dryburst100.set'])), preload=True, verbose=False).resample(sfreq=sfreq)
  raw_eeglab = [mne.io.read_raw_eeglab(os.path.join(path, '_'.join([participants[i], 'dryburst100.set'])), preload=True, verbose=False).resample(sfreq=sfreq)
  raw_eeglab = [mne.io.read_raw_eeglab(os.path.join(path, '_'.join([participants[i], 'dryburst100.set'])), preload=True, verbose=False).resample(sfreq=sfreq)
  raw_eeglab = [mne.io.read_raw_eeglab(os.path.join(

# Preprocessing

Perform preprocessing to have better signal as input of the classifier

In [16]:
for ind_i,i in enumerate(participants):

    ## Lines if you want to drop or keep channels
    # raw = raw.drop_channels([ch for ch in raw.ch_names if ch in to_drop])
    # raw = raw.drop_channels([i for i in raw.ch_names if i not in keep])

    raw_eeglab[ind_i] = raw_eeglab[ind_i].filter(l_freq=fmin, h_freq=fmax, method="fir", verbose=True)

    n_channels = len(raw_eeglab[ind_i].ch_names)
    print("Channels :", raw_eeglab[ind_i].ch_names)

Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 45 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 45.00 Hz
- Upper transition bandwidth: 11.25 Hz (-6 dB cutoff frequency: 50.62 Hz)
- Filter length: 1651 samples (3.302 s)



Channels : ['EEG 000', 'EEG 001', 'EEG 002', 'EEG 003', 'EEG 004', 'EEG 005', 'EEG 006', 'EEG 007']
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 45 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 45.00 Hz
- Upper transition bandwidth: 11.25 Hz (-6 dB cutoff frequency: 50.62 Hz)
- Filter length: 1651 samples (3.302 s)

Channels : ['EEG 000', 'EEG 001', 'EEG 002', 'EEG 003', 'EEG 004', 'EEG 005', 'EEG 006', 'EEG 007']
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 45 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   8 out of   8 | elapsed:    0.0s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   8 out of   8 | elapsed:    0.0s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with

Channels : ['EEG 000', 'EEG 001', 'EEG 002', 'EEG 003', 'EEG 004', 'EEG 005', 'EEG 006', 'EEG 007']
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 45 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 45.00 Hz
- Upper transition bandwidth: 11.25 Hz (-6 dB cutoff frequency: 50.62 Hz)
- Filter length: 1651 samples (3.302 s)

Channels : ['EEG 000', 'EEG 001', 'EEG 002', 'EEG 003', 'EEG 004', 'EEG 005', 'EEG 006', 'EEG 007']
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 45 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   8 out of   8 | elapsed:    0.0s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   8 out of   8 | elapsed:    0.0s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with

Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 45 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 45.00 Hz
- Upper transition bandwidth: 11.25 Hz (-6 dB cutoff frequency: 50.62 Hz)
- Filter length: 1651 samples (3.302 s)

Channels : ['EEG 000', 'EEG 001', 'EEG 002', 'EEG 003', 'EEG 004', 'EEG 005', 'EEG 006', 'EEG 007']
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 45 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   8 out of   8 | elapsed:    0.0s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   8 out of   8 | elapsed:    0.0s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with

Channels : ['EEG 000', 'EEG 001', 'EEG 002', 'EEG 003', 'EEG 004', 'EEG 005', 'EEG 006', 'EEG 007']
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 45 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 45.00 Hz
- Upper transition bandwidth: 11.25 Hz (-6 dB cutoff frequency: 50.62 Hz)
- Filter length: 1651 samples (3.302 s)

Channels : ['EEG 000', 'EEG 001', 'EEG 002', 'EEG 003', 'EEG 004', 'EEG 005', 'EEG 006', 'EEG 007']
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 45 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   8 out of   8 | elapsed:    0.0s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   8 out of   8 | elapsed:    0.0s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with

Channels : ['EEG 000', 'EEG 001', 'EEG 002', 'EEG 003', 'EEG 004', 'EEG 005', 'EEG 006', 'EEG 007']
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 45 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 45.00 Hz
- Upper transition bandwidth: 11.25 Hz (-6 dB cutoff frequency: 50.62 Hz)
- Filter length: 1651 samples (3.302 s)

Channels : ['EEG 000', 'EEG 001', 'EEG 002', 'EEG 003', 'EEG 004', 'EEG 005', 'EEG 006', 'EEG 007']
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 45 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   8 out of   8 | elapsed:    0.0s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   8 out of   8 | elapsed:    0.0s finished
[Parallel(n_jobs=1)]: Using backend SequentialBackend with

## Get the epochs and labels

In [53]:
def changeEventID(events, event_id):
    new_dic = {}
    for k in event_id.keys():
        new_dic[k.split('_')[1]] = event_id[k]
        event_id[k] = int(k.split('_')[1])
    for i in range(len(events)):
        events[i][2] = int(new_dic[str(events[i][2])])
    
    return events, event_id 


In [54]:
epochs_list = []
events_list = []
events_id_list = []
onset_code_list = []
data_list = []
labels_code_list = []
for ind_i, i in enumerate(participants): 
    # Strip the annotations that were script to make them easier to process
    events, event_id = mne.events_from_annotations(raw_eeglab[ind_i], event_id='auto', verbose=False)
    to_remove = []
    for idx in range(len(raw_eeglab[ind_i].annotations.description)):
        if (('boundary' in raw_eeglab[ind_i].annotations.description[idx]) or
            ('BURST' in raw_eeglab[ind_i].annotations.description[idx])):
            to_remove.append(idx)

    to_remove = np.array(to_remove)
    if len(to_remove) > 0:
        raw_eeglab[ind_i].annotations.delete(to_remove)

    # Get the events
    temp_event,temp_event_id = mne.events_from_annotations(raw_eeglab[ind_i], event_id='auto', verbose=False)
    temp_event,temp_event_id = changeEventID(temp_event,temp_event_id)
    events_list.append(temp_event)
    events_id_list.append(temp_event_id)
    shift = 0.0

    # Epoch the data following event
    epochs_list.append(mne.Epochs(raw_eeglab[ind_i], events_list[ind_i], event_id=events_id_list[ind_i], tmin=shift, \
                tmax=2.2+shift, baseline=(None, None), preload=False, verbose=False))

    # Add the data and the labels to the list
    labels_code_list.append(epochs_list[ind_i].events[..., -1])
    labels_code_list[ind_i] -= np.min(labels_code_list[ind_i])
    data_list.append(epochs_list[ind_i].get_data())
    info_ep = epochs_list[ind_i].info
    
    onset_code_list.append(epochs_list[ind_i].events[..., 0])
    
data_list = np.array(data_list)

In [55]:
from collections import OrderedDict

# Create the code to have all the order of the 0 and 1
codes = OrderedDict()
for k, v in events_id_list[0].items():
    code = k.split('_')[0]
    code = code.replace('.','').replace('2','')
    idx = k.split('_')[1]
    codes[v-1] = np.array(list(map(int, code)))

## Define usefull function to get good shape of the data

In [56]:
def to_window(data, labels,length,n_samples_windows,codes,window_size=0.25,normalise=True,sfreq=500,fps=60,n_channels=8):
    """
    transform epochs of 2.2 to epochs of duration window_size for each timestamp (one timestamp is 1/sfreq)
    """
    X = np.empty(shape=((length)*data.shape[0], n_channels, n_samples_windows))
    idx_taken = []
    y = np.empty(shape=((length)*data.shape[0]), dtype=int)
    count = 0
    for trial_nb, trial in enumerate(data):
        lab = labels[trial_nb]
        c = codes[lab]
        code_pos = 0
        for idx in range(length):
            X[count] = trial[:, idx:idx+n_samples_windows]
            if idx/sfreq >= (code_pos+1)/fps:
                code_pos += 1 
            y[count] = int(c[code_pos])
            count += 1
        
        # take index taken to be able to use for onset annotations
        for idx in range(length):
            idx_taken.append(trial_nb*length+idx)

    y_pred = np.vstack((y,np.abs(1-y))).T
    y = np.array([1 if (y >= 0.5) else 0 for y in y_pred[:,0]])

    return X, y, np.array(idx_taken)

def onset_anno(onset_window,label_window,onset_code,nb_seq_min,nb_seq_max,code_freq,sfreq,win_size):
    """
    Create the onset for the annotation of the raw_data to be able to use the raw_data more easily
    """
    assert(sfreq!=0)
    new_onset = []
    new_onset_0 = []
    current_code = 0
    onset_code = np.ceil(onset_code*code_freq/sfreq)
    nb_seq_min-=1
    onset_shift = onset_code[current_code+nb_seq_min]
    time_trial = (2.2-win_size)
    # onset_window = np.arange(0,time_trial*code_freq*(nb_seq_max-nb_seq_min)-1,1,dtype=int)
    for i,o in enumerate(onset_window):
        if label_window[i]==1:
            # print(i)
            if current_code==nb_seq_max-1-nb_seq_min:
                new_onset.append(o+onset_shift)
            else:
                if o+onset_shift >= onset_code[current_code+nb_seq_min]+time_trial*code_freq:
                    current_code+=1
                    onset_shift = onset_code[current_code+nb_seq_min]-time_trial*code_freq*current_code
                new_onset.append(o+onset_shift)
        else:
            if current_code==nb_seq_max-1-nb_seq_min:
                new_onset_0.append(o+onset_shift)
            else:
                if o+onset_shift >= onset_code[current_code+nb_seq_min]+time_trial*code_freq:
                    current_code+=1
                    onset_shift = onset_code[current_code+nb_seq_min]-time_trial*code_freq*current_code
                new_onset_0.append(o+onset_shift)
    
    # modified_onset_code = [onset_code[i]-time_trial*sfreq*i for i in range(nb_seq_min,nb_seq_max)]
    # new_onset_0 = np.concatenate([np.arange(onset_code[i],onset_code[i]+time_trial*sfreq,sfreq//60) for i in range(nb_seq_min,nb_seq_max)])
    new_onset_0 = np.array(list(filter(lambda i: i not in new_onset, new_onset_0)))
    # print(new_onset_0.shape)
    return np.array(new_onset)/code_freq, np.array(new_onset_0)/code_freq
            

In [61]:

n_samples_windows = int(window_size*sfreq)
length = int((2.2-window_size)*sfreq)
# length = int((2.2-window_size)*fps)

X_parent = np.zeros((data_list.shape[0],length*data_list.shape[1],data_list.shape[2],n_samples_windows))
Y_parent = np.zeros((data_list.shape[0],length*data_list.shape[1]))
idx_taken = np.zeros((data_list.shape[0],length*data_list.shape[1]))
domains_parent = []
for ind_i,i in enumerate(participants):
    # Get the epoch for small windows for each timestamp (sfreq=sfreq for each 1/500s and sfreq=fps for each frame)
    X_parent[ind_i],Y_parent[ind_i],idx_taken[ind_i] = to_window(data_list[ind_i],labels_code_list[ind_i],length,n_samples_windows,codes,window_size=window_size,sfreq=sfreq)
    domains_parent.append(["Source_sub_{}".format(ind_i+1),]*len(Y_parent[ind_i]))
domains_parent = np.array(domains_parent)
    

In [58]:

for ind_i,i in enumerate(participants):
    onset,onset_0 = onset_anno(idx_taken[ind_i],Y[ind_i],onset_code_list[ind_i],1,n_class*15,sfreq,sfreq,window_size)
    anno = mne.Annotations(onset,1/sfreq,"1")
    anno.append(onset_0,1/sfreq,"0")

    raw_eeglab[ind_i] = raw_eeglab[ind_i].set_annotations(anno)

X_parent = np.zeros((X.shape[0],length*data_list.shape[1],data_list.shape[2],int(window_size*sfreq)))
domains_parent = []
Y_parent = np.zeros((data_list.shape[0],length*data_list.shape[1]))
for ind_i,i in enumerate(participants):
    events, event_id = mne.events_from_annotations(raw_eeglab[ind_i])
    print(events.shape)
    epochs = mne.Epochs(raw_eeglab[ind_i],events,event_id,tmin=0.0,tmax=window_size,baseline=(0,0))
    X_parent[ind_i] = epochs.get_data()[:,:,:-1]
    Y_parent[ind_i] = epochs.events[...,-1]-1
    
    # Get domains per subject for the balancing later
    domains_parent.append(["Source_sub_{}".format(ind_i+1),]*len(Y_parent[ind_i]))

domains_parent = np.array(domains_parent)

(69375, 3)
(69375, 3)
(69375, 3)
(69375, 3)
(69375, 3)
(69375, 3)
(69375, 3)
(69375, 3)
(69375, 3)
(69375, 3)
(69375, 3)
(69375, 3)
(69375, 3)
(69375, 3)
(69375, 3)
(69375, 3)
(69375, 3)
(69375, 3)
(69375, 3)
(69375, 3)
(69375, 3)
(69375, 3)
(69375, 3)
(69375, 3)


# Create classifier

In [62]:
# XdawnCovariance to go in riemannian space + Tangent Space operation + LDA
model =  make_pipeline(XdawnCovariances(nfilter=8, estimator="lwf", xdawn_estimator="lwf",classes=[1]),
                        TangentSpace(), LDA(solver="lsqr", shrinkage="auto"))

# Train/test

In [63]:
n_cal = 7
n_class = 5
nb_fold = 1

# list to retrieve accuracy of code classification
spdbn_accuracy_code_perso = np.zeros((nb_fold,nb_subject))
# list to retrieve time of training
spdbn_tps_train_code_perso = np.zeros((nb_fold,nb_subject))
# list to retrieve time of testing
spdbn_tps_test_code_perso = np.zeros((nb_fold,nb_subject))
# list to retrieve time of accumulation prediction
spdbn_tps_pred_code_perso = np.zeros((nb_fold,nb_subject))
# list to retrieve accuracy of bits classification (0 and 1)
spdbn_accuracy_perso = np.zeros((nb_fold,nb_subject))
# list to retrieve recall of bits classification (0 and 1)
spdbn_recall_perso = np.zeros((nb_fold,nb_subject))
# list to retrieve f1 score of bits classification (0 and 1)
spdbn_f1_perso = np.zeros((nb_fold,nb_subject))

for k in range(nb_fold):
    for i in range(nb_subject):
        print("TL to the participant : ", i)
        X = X_parent.copy()
        Y = Y_parent.copy()
        domains = domains_parent.copy()

        # #preprocess for Domain adaptation (train on n-1 subject and a part of the nth subject and test on the rest of the data of the nth subject)
        # if you need to do some preprocessing that need information of the label or the data (recentering, transform in covariances...)
        # For example here we are normalizing the data
        for j in range(nb_subject):
            if j==i:
                X_std = X[j][:nb_sample_cal].std(axis=0)
                X[j] = X[j]/(X_std + 1e-8)
            else:
                X_std = X[j].std(axis=0)
                X[j] = X[j]/(X_std + 1e-8)
        
        # Domain Adaptation
        n_cal = 4
        nb_sample_cal = int(n_class*n_cal*(2.2-window_size)*sfreq)
        ind2take = [j for j in range(len(participants)) if j!=i]
        ind2take = [j for j in range(6) if j!=i]
        X_train = np.concatenate([np.concatenate(X[ind2take]).reshape(-1,X.shape[-2],X.shape[-1]),X[i][:nb_sample_cal]]).reshape(-1,X.shape[-2],X.shape[-1])
        Y_train = np.concatenate([np.concatenate(Y[ind2take]).reshape(-1),Y[i][:nb_sample_cal]]).reshape(-1)
        X_test = X[i][nb_sample_cal:]
        Y_test = Y[i][nb_sample_cal:]
        labels_code_test = labels_code_list[i][(n_class*4):]
        domains_train = np.concatenate([np.concatenate(domains[ind2take]).reshape(-1),domains[i][:nb_sample_cal]]).reshape(-1)
        domain_test = domains[i][nb_sample_cal]
                    
        

        # ind2take = [j for j in range(len(participants)) if j!=i]
        # X_train = np.concatenate(X[ind2take]).reshape(-1,X.shape[-2],X.shape[-1])
        # Y_train = np.concatenate(Y[ind2take]).reshape(-1)
        # X_test = X[i]
        # Y_test = Y[i]
        # labels_code_test = labels_code_list[i]
        # domains_train = np.concatenate(domains[ind2take]).reshape(-1)
        # ## preprocess for Domain Generalisation (train on n-1 subject  and test on the nth subject)
        # # if you need to do some preprocessing that need information of the label or the data (recentering, transform in covariances...)
        # # For example here we are normalizing the data
        # X_std = X_train.std(axis=0)
        # X_train /= X_std + 1e-8
        # X_test /= X_std + 1e-8
        

        # # preprocess for classical train/test classification
        # # if you need to do some preprocessing that need information of the label or the data (recentering, transform in covariances...)
        # # For example here we are normalizing the data
        # nb_sample_cal = int(n_class*n_cal*(2.2-window_size)*sfreq)
        # X_std = X[i][:nb_sample_cal].std(axis=0)
        # X[i] /= X_std + 1e-8

        # # Train/test
        # X_train = X[i][:nb_sample_cal]
        # Y_train = Y[i][:nb_sample_cal]
        # X_test = X[i][nb_sample_cal:]
        # Y_test = Y[i][nb_sample_cal:]
        # domains_train = domains[i][:nb_sample_cal]
        # labels_code_test = labels_code_list[i][(n_class*n_cal):]


        #Balancing the number of 0 and 1 as there are around 4 times more 0 than 1
        print("balancing the number of ones and zeros")
        X_train, Y_train, domains_train = balance(X_train,Y_train,domains_train)
        print(X_train.shape)
        print(Y_train.shape)
        print(X_test.shape)

        print("Creating the different pipelines")
        lr = 1e-3
        batchsize = 64 #128 # 64 for burst
        clf = deepcopy(model)
        
        x_train, x_val, y_train, y_val = train_test_split(X_train, Y_train, test_size=0.2, random_state=42, shuffle=True)

        print("Fitting")
        start = time.time()
        lr = 1e-3
        clf.fit(np.array(x_train), y_train)

        spdbn_tps_train_code_perso[k][i] = time.time() - start

        print("getting accuracy of participant ", i)
        start = time.time()
        y_pred = clf.predict(X_test)
        y_pred = np.array(y_pred)
        y_pred_norm = np.array([1 if (y >= 0.5) else 0 for y in y_pred])
        y_test_norm = np.array([0 if y == 0 else 1 for y in Y_test])

        spdbn_accuracy_perso[k][i] = balanced_accuracy_score(y_test_norm,y_pred_norm)
        spdbn_recall_perso[k][i] = recall_score(y_test_norm,y_pred_norm)
        spdbn_f1_perso[k][i] = f1_score(y_test_norm,y_pred_norm)
        print(f"Test Accuracy: {spdbn_accuracy_perso[k][i]}")
        print(f"Test recall: {spdbn_recall_perso[k][i]}")
        print(f"Test f1: {spdbn_f1_perso[k][i]}")

        labels_pred_accumul, _, mean_long_accumul = make_preds_accumul_aggresive(
            y_pred_norm, codes, min_len=30, sfreq=sfreq, consecutive=50, window_size=window_size
        )
        spdbn_tps_test_code_perso[k][i] = time.time() - start
        spdbn_accuracy_code_perso[k][i] = np.round(balanced_accuracy_score(labels_code_test[labels_pred_accumul!=-1], labels_pred_accumul[labels_pred_accumul!=-1]), 2)
        spdbn_tps_pred_code_perso[k][i] = np.mean(mean_long_accumul)
        # keras.backend.clear_session()

spdbn_accuracy_perso = np.mean(spdbn_accuracy_perso,axis=0)
spdbn_tps_train_code_perso = np.mean(spdbn_tps_train_code_perso,axis=0)
spdbn_tps_test_code_perso = np.mean(spdbn_tps_test_code_perso,axis=0)
spdbn_accuracy_code_perso = np.mean(spdbn_accuracy_code_perso,axis=0)

#print mean of the different measure collected
print(spdbn_accuracy_perso)
print(spdbn_recall_perso)
print(spdbn_f1_perso)
print(spdbn_tps_train_code_perso)
print(spdbn_tps_test_code_perso)
print(np.mean(spdbn_tps_pred_code_perso,axis=0))
print(spdbn_accuracy_code_perso)

#you can save the measure collected after

TL to the participant :  0
balancing the number of ones and zeros
(34128, 8, 175)
(34128,)
(50875, 8, 175)
Creating the different pipelines
Fitting
getting accuracy of participant  0
Test Accuracy: 0.7587508963529603
Test recall: 0.7664141414141414
Test f1: 0.22386133136640235
TL to the participant :  1
balancing the number of ones and zeros
(34128, 8, 175)
(34128,)
(50875, 8, 175)
Creating the different pipelines
Fitting
getting accuracy of participant  1
Test Accuracy: 0.8107610848028176
Test recall: 0.7693602693602694
Test f1: 0.32143485141550904


  c /= stddev[:, None]
  c /= stddev[None, :]


TL to the participant :  2
balancing the number of ones and zeros
(34128, 8, 175)
(34128,)
(50875, 8, 175)
Creating the different pipelines
Fitting
getting accuracy of participant  2
Test Accuracy: 0.7297851883925823
Test recall: 0.7239057239057239
Test f1: 0.20335776779380468
TL to the participant :  3
balancing the number of ones and zeros
(34128, 8, 175)
(34128,)
(50875, 8, 175)
Creating the different pipelines
Fitting
getting accuracy of participant  3
Test Accuracy: 0.7562502720560103
Test recall: 0.726010101010101
Test f1: 0.23865522966242392
TL to the participant :  4
balancing the number of ones and zeros
(34128, 8, 175)
(34128,)
(50875, 8, 175)
Creating the different pipelines
Fitting
getting accuracy of participant  4
Test Accuracy: 0.677111135548423
Test recall: 0.6582491582491582
Test f1: 0.1674070109713674
TL to the participant :  5
balancing the number of ones and zeros
(34128, 8, 175)
(34128,)
(50875, 8, 175)
Creating the different pipelines
Fitting
getting accuracy of p

In [64]:
np.mean(spdbn_accuracy_code_perso)

0.7808333333333334