# Physionet Dataset

https://archive.physionet.org/pn4/eegmmidb/

In summary, the experimental runs were:
(list of number according by folder)

- 1. Baseline, eyes open
- 2. Baseline, eyes closed
- 3. Task 1 (open and close left or right fist)
- 4. Task 2 (imagine opening and closing left or right fist)
- 5. Task 3 (open and close both fists or both feet)
- 6. Task 4 (imagine opening and closing both fists or both feet)
- 7. Task 1
- 8. Task 2
- 9. Task 3
- 10. Task 4
- 11. Task 1
- 12. Task 2
- 13. Task 3
- 14. Task 4

In [62]:
import mne
import numpy as np
from mne.datasets import eegbci
import matplotlib.pyplot as plt
from os import listdir
from mne.channels import make_standard_montage
from scipy import signal
from scipy.linalg import sqrtm, inv 
from sklearn.model_selection import train_test_split, GridSearchCV, KFold
from sklearn.utils import shuffle
from mne.decoding import CSP
from sklearn.svm import SVC
from sklearn.manifold import TSNE
import seaborn as sns
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import ShuffleSplit,StratifiedKFold ,cross_val_score, cross_val_predict, KFold
from sklearn.metrics import classification_report,confusion_matrix
from sklearn.neighbors import KNeighborsClassifier
from pyriemann.utils.distance import distance_riemann
from scipy.linalg import logm, expm
import random
random.seed(42)

target_class = ["Left", "Right", "Non"] 
target_data_0 = "P007"
calibrate_size = 30
alignmentMethod = "EA"

con_for_cnn = "EA"
train_svm = False

condition_wLTL = "EA"

In [63]:
selected_ch = ['Fz..','C3..', 'Cz..','C4..','Pz..']

num_subject = 30

raw_RorL1sub = [0]*num_subject  # RorL is Right or Left fist movement/imagery
raw_Both1sub = [0]*num_subject  # Both is both feet and both fits movement/imagery

start_subject = 0

RAW_data_RorL = {}
RAW_data_Both = {}

for j in range(1+start_subject,num_subject+1+start_subject):

    print("processing subject number: ", j)

    if j < 10:
        subject  = '00' + str(j)
    elif j < 100:
        subject = '0' + str(j)
    else:
        subject = str(j)

    raw_RorL1 = mne.io.read_raw_edf("D:\physionet_dataset\S" + str(subject) +"\S" + str(subject) +"R04.edf",preload = True, verbose=False)
    raw_RorL2 = mne.io.read_raw_edf("D:\physionet_dataset\S" + str(subject) +"\S" + str(subject) +"R08.edf",preload = True, verbose=False)
    raw_RorL3 = mne.io.read_raw_edf("D:\physionet_dataset\S" + str(subject) +"\S" + str(subject) +"R12.edf",preload = True, verbose=False)

    raw_RorL1sub[j-1+start_subject] = mne.concatenate_raws([raw_RorL1.pick(selected_ch), raw_RorL2.pick(selected_ch), raw_RorL3.pick(selected_ch)])
    raw_RorL1sub[j-1+start_subject] = raw_RorL1sub[j-1+start_subject].resample(128) 

    eegbci.standardize(raw_RorL1sub[j-1+start_subject])  # set channel names
    montage = make_standard_montage("standard_1005")    
    raw_RorL1sub[j-1+start_subject].set_montage(montage)

    RAW_data_RorL["P" + str(subject)] = {"Raw_data": raw_RorL1sub[j-1+start_subject].copy()}

    raw_Both1 = mne.io.read_raw_edf("D:\physionet_dataset\S" + str(subject) +"\S" + str(subject) +"R06.edf",preload = True, verbose=False)
    raw_Both2 = mne.io.read_raw_edf("D:\physionet_dataset\S" + str(subject) +"\S" + str(subject) +"R10.edf",preload = True, verbose=False)
    raw_Both3 = mne.io.read_raw_edf("D:\physionet_dataset\S" + str(subject) +"\S" + str(subject) +"R14.edf",preload = True, verbose=False)

    raw_Both1sub[j-1+start_subject] = mne.concatenate_raws([raw_Both1.pick(selected_ch), raw_Both2.pick(selected_ch), raw_Both3.pick(selected_ch)])
    raw_Both1sub[j-1+start_subject] = raw_Both1sub[j-1+start_subject].resample(128)

    eegbci.standardize(raw_Both1sub[j-1+start_subject])  # set channel names
    montage = make_standard_montage("standard_1005")    
    raw_Both1sub[j-1+start_subject].set_montage(montage)

    RAW_data_Both["P" + str(subject)] = {"Raw_data": raw_Both1sub[j-1+start_subject].copy()}

processing subject number:  1
processing subject number:  2
processing subject number:  3
processing subject number:  4
processing subject number:  5
processing subject number:  6
processing subject number:  7
processing subject number:  8
processing subject number:  9
processing subject number:  10
processing subject number:  11
processing subject number:  12
processing subject number:  13
processing subject number:  14
processing subject number:  15
processing subject number:  16
processing subject number:  17
processing subject number:  18
processing subject number:  19
processing subject number:  20
processing subject number:  21
processing subject number:  22
processing subject number:  23
processing subject number:  24
processing subject number:  25
processing subject number:  26
processing subject number:  27
processing subject number:  28
processing subject number:  29
processing subject number:  30


In [64]:
def butter_bandpass(lowcut,highcut,fs,order):
    nyq = 0.5*fs
    low = lowcut/nyq
    high = highcut/nyq
    b,a = signal.butter(order,[low,high],'bandpass')
    return b,a

def butter_bandpass_filter(data,lowcut = 6,highcut = 30, order = 4):
    b,a = butter_bandpass(lowcut,highcut,128,order)
    y = signal.filtfilt(b,a,data,axis=2)
    return y

def Get_epoch(RAW_data_RorL, RAW_data_Both, tmin=-2.0, tmax=4.0, crop=(0,2),baseline = (-0.5,0.0), trial_removal_th = 100):
    EEG_epoch = {}
    for key_subs in RAW_data_RorL:

        event_id_mapping_RorL = {
            old_event_id: new_event_id
            for old_event_id, new_event_id in zip([1, 2, 3], [2, 0, 1])  # Example mapping
        }

        events_RorL, event_id1 = mne.events_from_annotations(RAW_data_RorL[key_subs]['Raw_data'], verbose=False)
        events_RorL[:, 2] = [event_id_mapping_RorL.get(event_id1, event_id1) for event_id1 in events_RorL[:, 2]]

        event_id1 = {'Rest': 2, 'Left': 0, 'Right': 1 } 

        RorL_epochs = mne.Epochs(RAW_data_RorL[key_subs]['Raw_data'], events_RorL, 
            tmin= tmin,     # init timestamp of epoch (0 means trigger timestamp same as event start)
            tmax= tmax,    # final timestamp (10 means set epoch duration 10 second)
            event_id= event_id1,
            preload = True,
            event_repeated='drop',
            baseline=baseline,
            verbose = False
        )

        ########################################################################################################

        event_id_mapping_Both = {
            old_event_id: new_event_id
            for old_event_id, new_event_id in zip([1, 2, 3], [2, 4, 3])  # Example mapping
        }

        events_Both, event_id2 = mne.events_from_annotations(RAW_data_Both[key_subs]['Raw_data'], verbose=False)
        events_Both[:, 2] = [event_id_mapping_Both.get(event_id2, event_id2) for event_id2 in events_Both[:, 2]]

        event_id2 = {'Rest': 2,'both_feet': 3}  # Don't use both fits

        Both_epochs = mne.Epochs(RAW_data_Both[key_subs]['Raw_data'], events_Both, 
            tmin= tmin,     # init timestamp of epoch (0 means trigger timestamp same as event start)
            tmax= tmax,    # final timestamp (10 means set epoch duration 10 second)
            event_id= event_id2,
            preload = True,
            event_repeated='drop',
            baseline=baseline,
            verbose = False
        )

        combine_epoch = mne.concatenate_epochs([RorL_epochs, Both_epochs])

        EEG_epoch[key_subs] =  {"Raw_Epoch": combine_epoch.copy().crop(tmin= crop[0], tmax= crop[1])}

        train_data = EEG_epoch[key_subs]['Raw_Epoch'].copy().get_data() * 10e5
        labels = EEG_epoch[key_subs]["Raw_Epoch"].copy().events[:,-1]

        outlier_trial = []
        for ii in range(0,train_data.shape[0]):
            if train_data[ii].max() > trial_removal_th or train_data[ii].min() < -trial_removal_th:
                outlier_trial.append(ii)
                print(key_subs,train_data[ii].min(), ii)
                print(key_subs,train_data[ii].max(), ii)

        EEG_epoch[key_subs]['Raw_Epoch'] = np.delete(train_data, outlier_trial, axis = 0)
        EEG_epoch[key_subs]['label'] = np.delete(labels, outlier_trial)

        filtered_data = butter_bandpass_filter(EEG_epoch[key_subs]['Raw_Epoch'], lowcut= 6, highcut= 32)
        # EEG_epoch[key_subs]['Raw_Epoch'] = filtered_data

        random_delete = random.sample(list(np.where(EEG_epoch[key_subs]['label']== 2)[0]), 60)

        EEG_epoch[key_subs]['Raw_Epoch'] = np.delete(filtered_data, random_delete, axis = 0)
        EEG_epoch[key_subs]['label'] = np.delete(EEG_epoch[key_subs]['label'], random_delete)

        if "Left" not in target_class:
                EEG_epoch[key_subs]['Raw_Epoch'] = np.delete(EEG_epoch[key_subs]['Raw_Epoch'], np.where(EEG_epoch[key_subs]['label']== 0), axis = 0)
                EEG_epoch[key_subs]['label'] = np.delete(EEG_epoch[key_subs]['label'], np.where(EEG_epoch[key_subs]['label']== 0))

        if "Right" not in target_class:
            EEG_epoch[key_subs]['Raw_Epoch'] = np.delete(EEG_epoch[key_subs]['Raw_Epoch'], np.where(EEG_epoch[key_subs]['label']== 1), axis = 0)
            EEG_epoch[key_subs]['label'] = np.delete(EEG_epoch[key_subs]['label'], np.where(EEG_epoch[key_subs]['label']== 1))

        if "Non" not in target_class:
            EEG_epoch[key_subs]['Raw_Epoch'] = np.delete(EEG_epoch[key_subs]['Raw_Epoch'], np.where(EEG_epoch[key_subs]['label']== 2), axis = 0)
            EEG_epoch[key_subs]['label'] = np.delete(EEG_epoch[key_subs]['label'], np.where(EEG_epoch[key_subs]['label']== 2))

        if "Feet" not in target_class:
            EEG_epoch[key_subs]['Raw_Epoch'] = np.delete(EEG_epoch[key_subs]['Raw_Epoch'], np.where(EEG_epoch[key_subs]['label']== 3), axis = 0)
            EEG_epoch[key_subs]['label'] = np.delete(EEG_epoch[key_subs]['label'], np.where(EEG_epoch[key_subs]['label']== 3))
        
    return EEG_epoch

EEG_Epochs = Get_epoch(RAW_data_RorL, RAW_data_Both, tmin=-2.0, tmax=4.0, crop=(0,4),baseline = (-0.5,0.0), trial_removal_th = 1000)


Not setting metadata
153 matching events found
Applying baseline correction (mode: mean)
Not setting metadata
150 matching events found
Applying baseline correction (mode: mean)
Not setting metadata
153 matching events found
Applying baseline correction (mode: mean)
Not setting metadata
152 matching events found
Applying baseline correction (mode: mean)


  combine_epoch = mne.concatenate_epochs([RorL_epochs, Both_epochs])
  train_data = EEG_epoch[key_subs]['Raw_Epoch'].copy().get_data() * 10e5
  combine_epoch = mne.concatenate_epochs([RorL_epochs, Both_epochs])
  train_data = EEG_epoch[key_subs]['Raw_Epoch'].copy().get_data() * 10e5
  combine_epoch = mne.concatenate_epochs([RorL_epochs, Both_epochs])
  train_data = EEG_epoch[key_subs]['Raw_Epoch'].copy().get_data() * 10e5
  combine_epoch = mne.concatenate_epochs([RorL_epochs, Both_epochs])
  train_data = EEG_epoch[key_subs]['Raw_Epoch'].copy().get_data() * 10e5


Not setting metadata
151 matching events found
Applying baseline correction (mode: mean)
Not setting metadata
152 matching events found
Applying baseline correction (mode: mean)
Not setting metadata
152 matching events found
Applying baseline correction (mode: mean)
Not setting metadata
151 matching events found


  combine_epoch = mne.concatenate_epochs([RorL_epochs, Both_epochs])
  train_data = EEG_epoch[key_subs]['Raw_Epoch'].copy().get_data() * 10e5
  combine_epoch = mne.concatenate_epochs([RorL_epochs, Both_epochs])
  train_data = EEG_epoch[key_subs]['Raw_Epoch'].copy().get_data() * 10e5
  combine_epoch = mne.concatenate_epochs([RorL_epochs, Both_epochs])
  train_data = EEG_epoch[key_subs]['Raw_Epoch'].copy().get_data() * 10e5
  combine_epoch = mne.concatenate_epochs([RorL_epochs, Both_epochs])


Applying baseline correction (mode: mean)
Not setting metadata
151 matching events found
Applying baseline correction (mode: mean)
Not setting metadata
151 matching events found
Applying baseline correction (mode: mean)
Not setting metadata
152 matching events found
Applying baseline correction (mode: mean)


  train_data = EEG_epoch[key_subs]['Raw_Epoch'].copy().get_data() * 10e5
  combine_epoch = mne.concatenate_epochs([RorL_epochs, Both_epochs])
  train_data = EEG_epoch[key_subs]['Raw_Epoch'].copy().get_data() * 10e5
  combine_epoch = mne.concatenate_epochs([RorL_epochs, Both_epochs])
  train_data = EEG_epoch[key_subs]['Raw_Epoch'].copy().get_data() * 10e5
  combine_epoch = mne.concatenate_epochs([RorL_epochs, Both_epochs])
  train_data = EEG_epoch[key_subs]['Raw_Epoch'].copy().get_data() * 10e5


Not setting metadata
152 matching events found
Applying baseline correction (mode: mean)
Not setting metadata
152 matching events found
Applying baseline correction (mode: mean)
Not setting metadata
150 matching events found
Applying baseline correction (mode: mean)
Not setting metadata
152 matching events found
Applying baseline correction (mode: mean)


  combine_epoch = mne.concatenate_epochs([RorL_epochs, Both_epochs])
  train_data = EEG_epoch[key_subs]['Raw_Epoch'].copy().get_data() * 10e5
  combine_epoch = mne.concatenate_epochs([RorL_epochs, Both_epochs])
  train_data = EEG_epoch[key_subs]['Raw_Epoch'].copy().get_data() * 10e5
  combine_epoch = mne.concatenate_epochs([RorL_epochs, Both_epochs])
  train_data = EEG_epoch[key_subs]['Raw_Epoch'].copy().get_data() * 10e5
  combine_epoch = mne.concatenate_epochs([RorL_epochs, Both_epochs])
  train_data = EEG_epoch[key_subs]['Raw_Epoch'].copy().get_data() * 10e5


Not setting metadata
150 matching events found
Applying baseline correction (mode: mean)
Not setting metadata
150 matching events found
Applying baseline correction (mode: mean)
Not setting metadata
152 matching events found
Applying baseline correction (mode: mean)
Not setting metadata


  combine_epoch = mne.concatenate_epochs([RorL_epochs, Both_epochs])
  train_data = EEG_epoch[key_subs]['Raw_Epoch'].copy().get_data() * 10e5
  combine_epoch = mne.concatenate_epochs([RorL_epochs, Both_epochs])
  train_data = EEG_epoch[key_subs]['Raw_Epoch'].copy().get_data() * 10e5
  combine_epoch = mne.concatenate_epochs([RorL_epochs, Both_epochs])
  train_data = EEG_epoch[key_subs]['Raw_Epoch'].copy().get_data() * 10e5
  combine_epoch = mne.concatenate_epochs([RorL_epochs, Both_epochs])


152 matching events found
Applying baseline correction (mode: mean)
Not setting metadata
151 matching events found
Applying baseline correction (mode: mean)
Not setting metadata
150 matching events found
Applying baseline correction (mode: mean)
Not setting metadata
153 matching events found
Applying baseline correction (mode: mean)


  train_data = EEG_epoch[key_subs]['Raw_Epoch'].copy().get_data() * 10e5
  combine_epoch = mne.concatenate_epochs([RorL_epochs, Both_epochs])
  train_data = EEG_epoch[key_subs]['Raw_Epoch'].copy().get_data() * 10e5
  combine_epoch = mne.concatenate_epochs([RorL_epochs, Both_epochs])
  train_data = EEG_epoch[key_subs]['Raw_Epoch'].copy().get_data() * 10e5
  combine_epoch = mne.concatenate_epochs([RorL_epochs, Both_epochs])
  train_data = EEG_epoch[key_subs]['Raw_Epoch'].copy().get_data() * 10e5


Not setting metadata
151 matching events found
Applying baseline correction (mode: mean)
Not setting metadata
152 matching events found
Applying baseline correction (mode: mean)
Not setting metadata
150 matching events found
Applying baseline correction (mode: mean)
Not setting metadata
153 matching events found
Applying baseline correction (mode: mean)


  combine_epoch = mne.concatenate_epochs([RorL_epochs, Both_epochs])
  train_data = EEG_epoch[key_subs]['Raw_Epoch'].copy().get_data() * 10e5
  combine_epoch = mne.concatenate_epochs([RorL_epochs, Both_epochs])
  train_data = EEG_epoch[key_subs]['Raw_Epoch'].copy().get_data() * 10e5
  combine_epoch = mne.concatenate_epochs([RorL_epochs, Both_epochs])
  train_data = EEG_epoch[key_subs]['Raw_Epoch'].copy().get_data() * 10e5
  combine_epoch = mne.concatenate_epochs([RorL_epochs, Both_epochs])
  train_data = EEG_epoch[key_subs]['Raw_Epoch'].copy().get_data() * 10e5


Not setting metadata
152 matching events found
Applying baseline correction (mode: mean)
Not setting metadata
151 matching events found
Applying baseline correction (mode: mean)
Not setting metadata
152 matching events found
Applying baseline correction (mode: mean)
Not setting metadata
150 matching events found
Applying baseline correction (mode: mean)


  combine_epoch = mne.concatenate_epochs([RorL_epochs, Both_epochs])
  train_data = EEG_epoch[key_subs]['Raw_Epoch'].copy().get_data() * 10e5
  combine_epoch = mne.concatenate_epochs([RorL_epochs, Both_epochs])
  train_data = EEG_epoch[key_subs]['Raw_Epoch'].copy().get_data() * 10e5
  combine_epoch = mne.concatenate_epochs([RorL_epochs, Both_epochs])
  train_data = EEG_epoch[key_subs]['Raw_Epoch'].copy().get_data() * 10e5
  combine_epoch = mne.concatenate_epochs([RorL_epochs, Both_epochs])
  train_data = EEG_epoch[key_subs]['Raw_Epoch'].copy().get_data() * 10e5


In [65]:
del RAW_data_Both
del RAW_data_RorL

In [66]:
EEG_Epochs['P001']['label']

array([1, 0, 0, 1, 1, 0, 1, 0, 1, 2, 0, 0, 1, 0, 2, 1, 0, 0, 1, 2, 0, 1,
       0, 1, 2, 0, 2, 1, 1, 0, 0, 1, 1, 2, 0, 0, 1, 0, 1, 2, 0, 0, 1, 2,
       1, 0, 0, 2, 1, 1, 0, 2, 1, 0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2])

In [67]:
def GetConfusionMatrix(models, X_train, X_test, y_train, y_test):
    y_pred = models.predict(X_train)
    print("Classification TRAIN DATA \n=======================")
    print(classification_report(y_true= y_train, y_pred=y_pred))
    print("Confusion matrix \n=======================")
    print(confusion_matrix(y_true= y_train, y_pred=y_pred))

    y_pred = models.predict(X_test)
    print("Classification TEST DATA \n=======================")
    print(classification_report(y_true=y_test, y_pred=y_pred))
    print("Confusion matrix \n=======================")
    print(confusion_matrix(y_true=y_test, y_pred=y_pred))
    

label_target = EEG_Epochs[target_data_0]['label']
x_train, x_test, y_train, y_test = train_test_split(EEG_Epochs[target_data_0]['Raw_Epoch'], label_target, test_size=0.3, random_state = 42, stratify=label_target)

csp = CSP(n_components = 5, reg=None, log=None, rank= 'info')
csp.fit(x_train, y_train)   

x_train = csp.transform(x_train)
x_test = csp.transform(x_test)

lda = LinearDiscriminantAnalysis()
score = cross_val_score(lda, x_train, y_train, cv= 5)
print("LDA only Cross-validation scores:", np.mean(score))
lda.fit(x_train, y_train)

GetConfusionMatrix(lda, x_train, x_test, y_train, y_test)

Computing rank from data with rank='info'
    MAG: rank 5 after 0 projectors applied to 5 channels
Reducing data rank from 5 -> 5
Estimating covariance using EMPIRICAL
Done.
Computing rank from data with rank='info'
    MAG: rank 5 after 0 projectors applied to 5 channels
Reducing data rank from 5 -> 5
Estimating covariance using EMPIRICAL
Done.
Computing rank from data with rank='info'
    MAG: rank 5 after 0 projectors applied to 5 channels
Reducing data rank from 5 -> 5
Estimating covariance using EMPIRICAL
Done.
LDA only Cross-validation scores: 0.7511111111111111
Classification TRAIN DATA 
              precision    recall  f1-score   support

           0       0.86      0.75      0.80        16
           1       0.93      0.87      0.90        15
           2       0.70      0.82      0.76        17

    accuracy                           0.81        48
   macro avg       0.83      0.81      0.82        48
weighted avg       0.82      0.81      0.81        48

Confusion matrix 

In [68]:
from BCIAllFunction import BCIFuntions

calibrate_size = calibrate_size / EEG_Epochs[target_data_0]['Raw_Epoch'].shape[0]
AllBCIClass = BCIFuntions(numclass = 4, frequency = 128, ch_pick = ['Fz','C3', 'Cz','C4','Pz'])

if alignmentMethod == "LA":
    AllBCIClass.ComputeLA(EEG_Epochs, target_subject= target_data_0, calibrate_size=calibrate_size)
    if calibrate_size != 0:
        target_data = target_data_0 + "_test"
    count = 0

    for index in range(len(EEG_Epochs[target_data]['KMediod_label'])):
        if EEG_Epochs[target_data]['label'][index] == EEG_Epochs[target_data]['KMediod_label'][index]:
            count += 1
    print(count/len(EEG_Epochs[target_data]['KMediod_label']) * 100)
    
else:
    AllBCIClass.GetRawSet_ComputeEA(EEG_Epochs, target_subject= target_data_0, calibrate_size=calibrate_size)
    if calibrate_size != 0:
        target_data = target_data_0 + "_test"

In [69]:
EEG_Epochs[target_data_0].keys()

dict_keys(['Raw_Epoch', 'label', 'Raw_left', 'Raw_right', 'Raw_non', 'Raw_feet', 'EA_left', 'EA_right', 'EA_feet', 'EA_non', 'EA_Epoch'])

In [70]:
AllBCIClass.classifyCSP_LDA(EEG_Epochs, target_subjects= target_data, condition = "noEA")

Computing rank from data with rank='info'
    MAG: rank 5 after 0 projectors applied to 5 channels
Reducing data rank from 5 -> 5
Estimating covariance using EMPIRICAL
Done.
Computing rank from data with rank='info'
    MAG: rank 5 after 0 projectors applied to 5 channels
Reducing data rank from 5 -> 5
Estimating covariance using EMPIRICAL
Done.
Computing rank from data with rank='info'
    MAG: rank 5 after 0 projectors applied to 5 channels
Reducing data rank from 5 -> 5
Estimating covariance using EMPIRICAL
Done.
LDA only Cross-validation scores: 0.41949676422293053
Classification TRAIN DATA 
              precision    recall  f1-score   support

           0       0.41      0.35      0.38       670
           1       0.43      0.40      0.41       655
           2       0.44      0.54      0.49       706

    accuracy                           0.43      2031
   macro avg       0.43      0.43      0.43      2031
weighted avg       0.43      0.43      0.43      2031

Confusion matrix

In [71]:
AllBCIClass.classifyCSP_LDA(EEG_Epochs, target_subjects= target_data, condition = "EA") 

Computing rank from data with rank='info'
    MAG: rank 5 after 0 projectors applied to 5 channels
Reducing data rank from 5 -> 5
Estimating covariance using EMPIRICAL
Done.
Computing rank from data with rank='info'
    MAG: rank 5 after 0 projectors applied to 5 channels
Reducing data rank from 5 -> 5
Estimating covariance using EMPIRICAL
Done.
Computing rank from data with rank='info'
    MAG: rank 5 after 0 projectors applied to 5 channels
Reducing data rank from 5 -> 5
Estimating covariance using EMPIRICAL
Done.
LDA only Cross-validation scores: 0.510069545059403
Classification TRAIN DATA 
              precision    recall  f1-score   support

           0       0.49      0.48      0.48       670
           1       0.49      0.46      0.48       655
           2       0.56      0.60      0.58       706

    accuracy                           0.52      2031
   macro avg       0.52      0.52      0.52      2031
weighted avg       0.52      0.52      0.52      2031

Confusion matrix 


In [72]:
if train_svm == True:
    AllBCIClass.classifyCSP_SVM(EEG_Epochs, target_subjects= target_data, condition = "noEA")

In [73]:
if train_svm == True:
    AllBCIClass.classifyCSP_SVM(EEG_Epochs, target_subjects= target_data, condition = "EA")

In [74]:
# from sklearn.metrics import classification_report,confusion_matrix, accuracy_score
# from sklearn.model_selection import train_test_split, GridSearchCV, KFold

# def discriminative_frequency_band_selection(data, true_label, label_1, label_2, x_test, y_test):
#     Bh, Bl = 32, 6  # Initial upper and lower frequency limits
#     A1, A2 = 0, 0  # Initial accuracy values

#     # Step 1: Finding Bh
#     while A1 >= A2:
#         Bh -= 2
#         A2 = A1        
#         if(Bl >= Bh):
#             break
#         A1 = train_classifier(Bh, Bl, data, label_1, label_2, true_label, x_test, y_test)  # Design filter and train classifier using CSP

#     Bh += 1  # Adjust Bh back
#     A1 = train_classifier(Bh, Bl, data, label_1, label_2, true_label, x_test, y_test)

#     if A1 >= A2:
#         A2 = A1
#     else:
#         Bh = Bh + 1 
#         A1 = A2

#     # Step 2: Finding Bl
#     while A1 >= A2:
#         Bl += 2
#         A2 = A1
#         if(Bl >= Bh):
#             break
#         A1 = train_classifier(Bh, Bl, data, label_1, label_2, true_label, x_test, y_test)  # Design filter and train classifier using CSP

#     Bl -= 1  # Adjust Bl back
#     A1 = train_classifier(Bh, Bl, data, label_1, label_2, true_label, x_test, y_test)

#     if A1 >= A2:
#         A2 = A1
#     else:
#         Bh = Bh + 1 
#         A1 = A2

#     return Bh, Bl, A1


# def train_classifier(Bh, Bl, data, label_1, label_2, true_label, x_test, y_test):

#     init_data = data[np.where((true_label == label_1) | (true_label == label_2))]
#     init_label = true_label[np.where((true_label == label_1) | (true_label == label_2))]

#     print(init_data.shape, init_label.shape)

#     x_test = x_test[np.where((y_test == label_1) | (y_test == label_2))]
#     y_test = y_test[np.where((y_test == label_1) | (y_test == label_2))]

#     filtered_data = butter_bandpass_filter(init_data, lowcut= Bl, highcut= Bh)
#     x_test = butter_bandpass_filter(x_test, lowcut= Bl, highcut= Bh)

#     csp = CSP(n_components = 5, reg=None, log=None, rank= 'info')
#     csp.fit(filtered_data, init_label)

#     x_train = csp.transform(filtered_data)
#     x_test = csp.transform(x_test)

#     # Initialize SVM with a linear kernel
#     clf = SVC()

#     param_grid = {
#         'C':[1],
#         'kernel': ['rbf'],  # Example kernels
#     }

#     grid_search = GridSearchCV(clf, param_grid, cv=5, scoring='accuracy')

#     grid_search.fit(x_train, init_label)
#     y_pred = grid_search.predict(x_test)

#     return accuracy_score(y_test, y_pred)


# def get_frequency_band(EEG_Epochs, target_data, condition = "noEA"):

#     train_data = None
#     train_label = None
#     test_data = None
#     test_label = None

#     if condition == "noEA":
#         query = "Raw_Epoch"
#     else:
#         query = "EA_Epoch"

#     for sub in EEG_Epochs.keys():
#         if sub == target_data:
#             test_data = EEG_Epochs[sub][query]
#             test_label = EEG_Epochs[sub]['label']

#         else:
#             if train_data is None:
#                 train_data = EEG_Epochs[sub][query]
#             else:
#                 train_data = np.concatenate((train_data, EEG_Epochs[sub][query]), axis=0)

#             if train_label is None:
#                 train_label = EEG_Epochs[sub]['label']
#             else:
#                 train_label = np.concatenate((train_label, EEG_Epochs[sub]['label']), axis=0)

#     indices = [0, 1, 2, 3]
#     pairs = []
#     band_high = []
#     band_low = []
#     acc= []
#     class_name = ['left', 'right', 'non', 'feet']

#     train_data, x_test_temp, train_label, y_test_temp = train_test_split(train_data, train_label, test_size=0.3, random_state = 42, stratify=train_label)

#     # Nested loop to generate all pairs without reversing and without self-pairing
#     for i in range(len(indices)):
#         for j in range(i + 1, len(indices)):
#             pairs.append((indices[i], indices[j]))
#             Bh, Bl, A1 = discriminative_frequency_band_selection(data=train_data, true_label=train_label, label_1=indices[i], label_2=indices[j], x_test=x_test_temp, y_test=y_test_temp)
#             band_high.append(Bh)
#             band_low.append(Bl)
#             acc.append(A1)

#     stack_csp_train = []
#     stack_csp_test = []
#     stack_cnn_train = []
#     stack_cnn_test = []

#     for j in range(0,1):
#         stack_csp_train = []
#         stack_csp_test = []
#         for i in range(0,len(band_high)):
            
#             # filter_x_train = butter_bandpass_filter(train_data[:,:,0:256+(128*j)] ,lowcut= band_low[i], highcut=band_high[i])
#             # filter_x_test = butter_bandpass_filter(test_data[:,:,0:256+(128*j)] ,lowcut= band_low[i], highcut=band_high[i])

#             filter_x_train = butter_bandpass_filter(train_data ,lowcut= band_low[i], highcut=band_high[i])
#             filter_x_test = butter_bandpass_filter(test_data ,lowcut= band_low[i], highcut=band_high[i])

#             csp = CSP(n_components = 5, reg=None, log=None, rank= 'info')
#             csp.fit(filter_x_train, train_label)

#             filter_x_train = csp.transform(filter_x_train)
#             filter_x_test = csp.transform(filter_x_test)

#             stack_csp_train.append(filter_x_train)
#             stack_csp_test.append(filter_x_test)

#         stack_cnn_train.append(np.hstack(np.array(stack_csp_train)))
#         stack_cnn_test.append(np.hstack(np.array(stack_csp_test)))

#     for i in range(len(indices)):
#         for j in range(i + 1, len(indices)):
#             print(f"The selected frequency band of {class_name[i]} vs {class_name[j]} is: Bl = {band_low[j-1+i]}, Bh = {band_high[j-1+i]}, Acc = {acc[j-1+i]}" )

#     train_for_cnn = np.transpose(stack_cnn_train, (1, 2, 0))
#     test_for_cnn = np.transpose(stack_cnn_test, (1, 2, 0))

#     return train_for_cnn, test_for_cnn, train_label, test_label

# train_for_cnn, test_for_cnn, train_label, test_label = get_frequency_band(EEG_Epochs, target_data, condition = con_for_cnn)

In [75]:
# lda = LinearDiscriminantAnalysis()
# score = cross_val_score(lda, train_for_cnn[:,:,0], train_label, cv= 5)
# print("LDA+CSP Cross-validation scores:", np.mean(score))

# lda.fit(train_for_cnn[:,:,0], train_label)

# y_pred = lda.predict(train_for_cnn[:,:,0])

# print("Classification TRAIN DATA \n=======================")
# print(classification_report(y_true=train_label, y_pred=y_pred))
# print("Confusion matrix \n=======================")
# print(confusion_matrix(y_true=train_label, y_pred=y_pred))

# y_pred = lda.predict(test_for_cnn[:,:,0])

# print("Classification TEST DATA \n=======================")
# print(classification_report(y_true=test_label, y_pred=y_pred))
# print("Confusion matrix \n=======================")
# print(confusion_matrix(y_true=test_label, y_pred=y_pred))

# WLTL

In [76]:
CSP2D_Epoch = AllBCIClass.computeCSPFeatures(EEG_Epochs, target_subject = target_data)

Computing rank from data with rank='info'
    MAG: rank 5 after 0 projectors applied to 5 channels
Reducing data rank from 5 -> 5
Estimating covariance using EMPIRICAL
Done.
Computing rank from data with rank='info'
    MAG: rank 5 after 0 projectors applied to 5 channels
Reducing data rank from 5 -> 5
Estimating covariance using EMPIRICAL
Done.
Computing rank from data with rank='info'
    MAG: rank 5 after 0 projectors applied to 5 channels
Reducing data rank from 5 -> 5
Estimating covariance using EMPIRICAL
Done.
Computing rank from data with rank='info'
    MAG: rank 5 after 0 projectors applied to 5 channels
Reducing data rank from 5 -> 5
Estimating covariance using EMPIRICAL
Done.
Computing rank from data with rank='info'
    MAG: rank 5 after 0 projectors applied to 5 channels
Reducing data rank from 5 -> 5
Estimating covariance using EMPIRICAL
Done.
Computing rank from data with rank='info'
    MAG: rank 5 after 0 projectors applied to 5 channels
Reducing data rank from 5 -> 5


In [77]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import CategoricalCrossentropy

# Custom loss function
class CustomLossLL1(tf.keras.losses.Loss):
    def __init__(self, lambda_t, model):
        super().__init__()
        self.lambda_t = lambda_t
        self.cross_entropy = CategoricalCrossentropy()
        self.model = model

    def call(self, y_true, y_pred):
        ce_loss = self.cross_entropy(y_true, y_pred)
        ws = self.get_weights_from_model()
        reg_term = self.regularization_term(ws)
        return ce_loss + self.lambda_t * reg_term

    def get_weights_from_model(self):
        model_weights = []
        for layer in self.model.layers:
            if len(layer.get_weights()) > 0:
                model_weights.append(layer.get_weights()[0])
        # return tf.concat([tf.reshape(w, [-1]) for w in model_weights], axis=0)
        return model_weights

    def regularization_term(self, ws):
        reg_term = tf.pow(tf.norm(ws, ord='euclidean'),2)
        return reg_term


# Custom training loop
def custom_train_step(model, optimizer, x, y, custom_loss):
    with tf.GradientTape() as tape:
        y_pred = model(x, training=True) # Perform a forward pass and compute the predictions
        loss = custom_loss(y, y_pred) # Compute the custom loss
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss


def train_weight_LL(X_train, y_train, lambd, num_tier=10, learning_rate = 0.01):
    n_classes = np.unique(y_train).size
    _, sequential_indices = np.unique(y_train, return_inverse=True)
    y_one_hot = tf.keras.utils.to_categorical(sequential_indices, num_classes=n_classes)

    lambda_t = lambd  # Regularization parameter

    model = Sequential([
        Dense(n_classes, input_shape=(X_train.shape[1],), activation='softmax')  # Adjust input_shape to match the number of features in X
    ])

    # Compile the model
    optimizer = Adam(learning_rate)
    custom_loss = CustomLossLL1(lambda_t, model)

    # Custom training loop
    epochs = num_tier
    lowest_loss = float('inf')
    best_weights = None
    for epoch in range(epochs):
        loss = custom_train_step(model, optimizer, X_train, y_one_hot, custom_loss)
        loss_value = loss.numpy()
        if loss_value < lowest_loss:
            lowest_loss = loss_value
            best_weights = [layer.get_weights() for layer in model.layers]

        if epoch % 20 == 0:
            print(f"Epoch {epoch}, Loss: {loss.numpy()}")

    # best_weights = [layer.get_weights() for layer in model.layers]

    return best_weights, lowest_loss

def build_clf_params(data, target_subjects ,condition):

    for sub in data.keys():
        if (sub  != target_subjects) and  (sub != target_data_0): #Don't apply weight to target subject
            # Where the tranining data is stored
            if condition == "noEA":
                X = data[sub]['Raw_csp']
                y = data[sub]['Raw_csp_label']
                store_ws = 'ws_Raw'

            else:
                X = data[sub]['EA_csp']
                y = data[sub]['EA_csp_label']
                store_ws = 'ws_EA'

            weights, loss = train_weight_LL(X_train=X, y_train=y, lambd= 0.1, num_tier=1000, learning_rate= 0.005)
            print("weights of ", str(sub), ": ", weights)
            print("Lowest loss of ", str(sub), ": ", loss)
            data[sub][store_ws] = weights

build_clf_params(CSP2D_Epoch, target_subjects= target_data ,condition = condition_wLTL)

Epoch 0, Loss: 4.595022678375244
Epoch 20, Loss: 2.323911428451538
Epoch 40, Loss: 1.499566912651062
Epoch 60, Loss: 1.423752784729004
Epoch 80, Loss: 1.4065316915512085
Epoch 100, Loss: 1.4046186208724976
Epoch 120, Loss: 1.4024800062179565
Epoch 140, Loss: 1.4002559185028076
Epoch 160, Loss: 1.3979237079620361
Epoch 180, Loss: 1.3955087661743164
Epoch 200, Loss: 1.3930059671401978
Epoch 220, Loss: 1.3904439210891724
Epoch 240, Loss: 1.3878345489501953
Epoch 260, Loss: 1.3851951360702515
Epoch 280, Loss: 1.3825395107269287
Epoch 300, Loss: 1.3798829317092896
Epoch 320, Loss: 1.3772389888763428
Epoch 340, Loss: 1.3746213912963867
Epoch 360, Loss: 1.372043251991272
Epoch 380, Loss: 1.3695178031921387
Epoch 400, Loss: 1.367056965827942
Epoch 420, Loss: 1.3646740913391113
Epoch 440, Loss: 1.3623806238174438
Epoch 460, Loss: 1.3601887226104736
Epoch 480, Loss: 1.3581104278564453
Epoch 500, Loss: 1.3561564683914185
Epoch 520, Loss: 1.3543387651443481
Epoch 540, Loss: 1.3526684045791626
Epoc

In [78]:
CSP2D_Epoch[target_data].keys()

dict_keys(['Raw_csp', 'Raw_csp_label', 'EA_csp', 'EA_csp_label'])

In [79]:
# First define the kl divergence
def KL_div(P, Q):
    # First convert to np array
    P = np.array(P)
    Q = np.array(Q)
    
    # Then compute their means, datain shape of samples x feat
    mu_P = np.mean(P, axis=0)
    mu_Q = np.mean(Q, axis=0)    

    
    # Compute their covariance
    sigma_P = np.cov(P, rowvar=False)
    sigma_Q = np.cov(Q, rowvar=False)  

    diff = mu_Q - mu_P

    inv_sigma_Q = np.linalg.inv(sigma_Q)
    term1 = np.dot(np.dot(diff.T, inv_sigma_Q), diff)
    
    # Calculate the trace term trace(Sigma_Q^{-1} * Sigma_P)
    term2 = np.trace(np.dot(inv_sigma_Q, sigma_P))
    
    # Calculate the determinant term ln(det(Sigma_P) / det(Sigma_Q))
    det_sigma0 = np.linalg.det(sigma_P)
    det_sigma1 = np.linalg.det(sigma_Q)

    
    epsilon = 1e-6
    term3 = np.log((det_sigma0+epsilon) / (det_sigma1+epsilon))
    
    print(term3)
    
    # Dimensionality of the data
    K = mu_P.shape[0]
    
    # KL divergence
    kl_div = 0.5 * (term1 + term2 - term3 - K)
    
    return kl_div

In [80]:
# Compute kl divergence of target subject to each source subject
def compute_all_kl_div(data, target_subjects , condition):
    '''
    Parameter:
    data, is the whole data containing target and source data
    '''
    kl_div_score = []

    if condition == "noEA":
        target_data = 'Raw_csp'
        label_name = 'Raw_csp_label'

    else:
        target_data = 'EA_csp'
        label_name = 'EA_csp_label'
        
    # cal P from target data
    label_tgt =  data[target_subjects][label_name]
    P_left =  data[target_subjects][target_data][np.where(label_tgt == 0)]
    P_right = data[target_subjects][target_data][np.where(label_tgt == 1)]
    P_non = data[target_subjects][target_data][np.where(label_tgt == 2)]
    P_feet = data[target_subjects][target_data][np.where(label_tgt == 3)]

    tgt_data = target_subjects + "_test"

    #cal Q from each source subject
    for sub in data.keys():
        if (sub != target_subjects) and (sub != tgt_data):
            label_src =  data[sub][label_name]
            Q_left =  data[sub][target_data][np.where(label_src == 0)]
            Q_right = data[sub][target_data][np.where(label_src == 1)]
            Q_non = data[sub][target_data][np.where(label_src == 2)]
            Q_feet = data[sub][target_data][np.where(label_src == 3)]

            kl_left = KL_div(P_left, Q_left)
            kl_right = KL_div(P_right, Q_right)
            kl_non = KL_div(P_non, Q_non)
            kl_feet = KL_div(P_feet, Q_feet)

            # kl_div = (kl_left + kl_right+ kl_non + kl_feet)/4

            kl_div_temp = [kl_left, kl_right, kl_non, kl_feet]

            kl_div_score.append(kl_div_temp)

    data[target_subjects]['kl_div'] = kl_div_score


compute_all_kl_div(CSP2D_Epoch, target_subjects=target_data_0 ,condition = condition_wLTL) #target_sub for cal KL is calibrate set

0.10229703248862927
0.0018504274013096628
0.049785329237039024
nan
0.10695465379963104
0.0032232506896612396
-0.39569116994468195
nan
0.10877999369091472
-0.08401380581420284
0.026632418549640755
nan
0.10899569936309919
0.007208964217125076
0.05074341020993564
nan
0.06523228002551716
-0.0021310633125235876
0.009826454788384549
nan
0.09940557110932145
-0.004082862409767903
0.03948479473464662
nan
0.10339452938778265
-0.002739120746363367
0.0004964598402161748
nan
0.10637527916649778
0.0068111322218710885
0.044494064534279384
nan
0.08840376882842864
0.006953018742465782
0.018443917415206144
nan
0.10756008213073144
0.003954573871558288
0.04618385807974789
nan
0.07320051543000514
-0.06871055610669395
0.05214084357288975
nan
0.0661325969769259
-0.02918395810920568
-0.02459649027969104
nan
0.10940095753192033
-0.0027621747522611744
-0.07960577585878781
nan
0.10825111011770695
0.004872159232428473
0.05338006894847506
nan
0.08943178455640784
-0.0017225733967886007
0.04841320221691047
nan
0.092

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = um.true_divide(
  avg = a.mean(axis, **keepdims_kw)
  sigma_P = np.cov(P, rowvar=False)
  c *= np.true_divide(1, fact)
  c *= np.true_divide(1, fact)
  sigma_Q = np.cov(Q, rowvar=False)
  r = _umath_linalg.det(a, signature=signature)


In [81]:
np.array(CSP2D_Epoch[target_data_0]['kl_div'])

array([[ 9.40080461,  3.77642588,  7.05485501,         nan],
       [ 9.89580141, 10.52444587,  0.54988031,         nan],
       [13.28231681,  2.75609206,  3.00377511,         nan],
       [11.24085164,  6.30648442, 11.50730529,         nan],
       [ 6.89217331,  8.82071535,  7.13506949,         nan],
       [10.26395138,  8.18904565,  6.35780678,         nan],
       [ 8.28305816,  5.82317642,  2.76843468,         nan],
       [ 5.31838809, 10.40244743,  6.0642871 ,         nan],
       [ 4.06653579,  9.76273921,  3.05128951,         nan],
       [10.82844831,  7.68417537,  7.48672341,         nan],
       [ 3.21290616,  5.92314285,  9.63718785,         nan],
       [ 3.64185178,  5.47040362,  3.90875732,         nan],
       [13.87786895,  4.43091317,  1.80127685,         nan],
       [ 8.81255711,  6.44232471, 11.04524664,         nan],
       [ 6.81740254,  9.60198579,  6.92681675,         nan],
       [ 6.30709163,  6.14082641,  4.63855151,         nan],
       [14.24036914, 10.

In [82]:
def compute_similarity_weights(data, target_subjects):
    kl = data[target_subjects]['kl_div']
    KL_inv_left = []
    KL_inv_right = []
    KL_inv_non = []
    KL_inv_feet = []

    alpha_s = []
    eps = 0.0001
    
    #equation (9)
    for val in kl:
        if val != 0: 
            KL_inv_left.append(1/((val[0] + eps)**4))
            KL_inv_right.append(1/((val[1] + eps)**4))
            KL_inv_non.append(1/((val[2] + eps)**4))
            KL_inv_feet.append(1/((val[3] + eps)**4))

    print(KL_inv_left)
    print(KL_inv_right)
    
    for i in range(0,len(KL_inv_left)):
        temp = [KL_inv_left[i]/sum(KL_inv_left), KL_inv_right[i]/sum(KL_inv_right), KL_inv_non[i]/sum(KL_inv_non), KL_inv_feet[i]/sum(KL_inv_feet)]
        alpha_s.append(temp)

    alpha_s = np.array(alpha_s)
    print(np.array(alpha_s[:, ~np.isnan(alpha_s).any(axis=0)]))
                
    data[target_subjects]['alpha_s'] = alpha_s[:, ~np.isnan(alpha_s).any(axis=0)]

compute_similarity_weights(CSP2D_Epoch, target_subjects=target_data_0)

[0.0001280328509003737, 0.00010427460664755261, 3.212858370999344e-05, 6.263076003701985e-05, 0.00044314926767848544, 9.009998745714325e-05, 0.00021243016356103039, 0.001249819328721526, 0.0036564436603680078, 7.273091420431783e-05, 0.009383260410525743, 0.005684119984765362, 2.695853912690839e-05, 0.00016579541253403026, 0.00046291234916815246, 0.0006319109684255113, 2.431659782538074e-05, 0.0007858864988799015, 0.0030841763630272637, 0.009443305299787257, 0.0008854984423231425, 0.0026595476282341756, 0.0005434335600775665, 0.00010840298645947809, 0.0007335546927487045, 0.0020265446582339967, 0.03419785120146954, 0.027191036295779767, 0.006302118159841352]
[0.004916206745552706, 8.150542921962119e-05, 0.01732854470043665, 0.0006321543706035462, 0.00016518289585395557, 0.0002223543437557585, 0.0008696217179319007, 8.539671839792355e-05, 0.00011007671315121094, 0.00028680605926043886, 0.000812384906195501, 0.0011165823530840506, 0.0025941067058282345, 0.0005805004917364637, 0.0001176353

In [83]:
def compute_ETL_and_mu_ws(data, target_subjects, condition):

    mu_ws = 0
    temp_ws = 0

    if condition == "noEA":
        ws_name = 'ws_Raw'
    else:
        ws_name = 'ws_EA'

    alpha_s = np.array(data[target_subjects]['alpha_s'])

    tgt_data = target_subjects + "_test"
    index_count = 0

    for sub in data.keys():
        if (sub != target_subjects) and (sub != tgt_data):
            ws = data[sub][ws_name][0][0]
            # mu_ws += ws @ alpha_s  #equation (10)
            # mu_ws += np.dot(ws, np.transpose(alpha_s))
            mu_ws += ws * alpha_s[index_count]
            index_count += 1

    print(np.array(mu_ws))

    index_count = 0
    for sub in data.keys():
        if (sub != target_subjects) and (sub != tgt_data):
            ws = data[sub][ws_name][0][0]
            # ws_min_mu = np.dot((np.dot(ws,np.transpose(alpha_s)) - mu_ws) , np.transpose((np.dot(ws,np.transpose(alpha_s)) - mu_ws)))
            ws_min_mu = np.dot(((ws * alpha_s[index_count]) - mu_ws), np.transpose((ws * alpha_s[index_count]) - mu_ws))
            temp_ws += ws_min_mu #equation (11)
            index_count += 1

    print(np.array(temp_ws))
    
    # den = np.diag(temp_ws) #get array in diagonal line

    den = temp_ws
    nom = np.trace(temp_ws) #Return the sum along diagonals of the array.
    Sigma_TL = den/nom


    data[target_subjects]['Sigma_TL'] = Sigma_TL
    data[target_subjects]['mu_ws'] = mu_ws

compute_ETL_and_mu_ws(CSP2D_Epoch, target_subjects = target_data_0, condition=condition_wLTL)

[[-0.20361998  0.63783641  0.02297078]
 [-0.30742249 -0.70227308 -0.10719426]
 [-0.09383623  0.72695871  0.24928293]
 [ 0.15439319 -0.35298721  0.55261605]
 [ 0.17576212 -0.1100445  -0.77231342]]
[[ 12.53385487 -10.90939753  13.66955605  -6.79387663  -3.44492415]
 [-10.90939753  16.71174699 -14.25187706   3.98091897   2.9850123 ]
 [ 13.66955605 -14.25187706  16.81814877  -3.70406489  -8.10560135]
 [ -6.79387663   3.98091897  -3.70406489  12.72591859 -10.12158134]
 [ -3.44492415   2.9850123   -8.10560135 -10.12158134  17.87512982]]


In [84]:
print(np.array(CSP2D_Epoch[target_data_0]['Sigma_TL']).shape)
print(np.array(CSP2D_Epoch[target_data_0]['mu_ws']).shape)

(5, 5)
(5, 3)


In [85]:
# Custom loss function
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import CategoricalCrossentropy
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

class CustomLossLL2(tf.keras.losses.Loss):
    
    def __init__(self, lambda_t, model, mu, sigma_TL):
        super().__init__()
        self.lambda_t = lambda_t
        self.cross_entropy = CategoricalCrossentropy()
        self.model = model
        self.mu = tf.convert_to_tensor(mu, dtype=tf.float32)
        self.sigma_TL = tf.convert_to_tensor(sigma_TL, dtype=tf.float32)

    def call(self, y_true, y_pred):
        ce_loss = self.cross_entropy(y_true, y_pred)
        wt = self.get_weights_from_model()
        reg_term = self.regularization_term(wt)

        return ce_loss + (self.lambda_t * tf.linalg.matmul(reg_term, wt))

    def get_weights_from_model(self):
        model_weights = []
        for layer in self.model.layers:
            if len(layer.get_weights()) > 0:
                model_weights.append(layer.get_weights()[0])
        # return tf.concat([tf.reshape(w, [-1]) for w in model_weights], axis=0)
        return model_weights

    def regularization_term(self, wt):
        diff = wt - self.mu
        reg_term = 0.5 * tf.linalg.matmul(tf.linalg.matmul(tf.linalg.inv(self.sigma_TL), diff[0]), tf.transpose(diff[0]))
        reg_term += 0.5 * tf.math.log(tf.linalg.det(self.sigma_TL))
        return reg_term


# Custom training loop
def custom_train_step(model, optimizer, x, y, custom_loss):
    with tf.GradientTape() as tape:
        y_pred = model(x, training=True) # Perform a forward pass and compute the predictions
        loss = custom_loss(y, y_pred) # Compute the custom loss
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss


def train_weight_LL2(X_train, y_train, lambd, mu, sigma_TL, num_tier=10, learning_rate = 0.01):
    n_classes = np.unique(y_train).size
    _, sequential_indices = np.unique(y_train, return_inverse=True)
    y_one_hot = tf.keras.utils.to_categorical(sequential_indices, num_classes=n_classes)

    lambda_t = lambd  # Regularization parameter

    model = Sequential([
        Dense(n_classes, input_shape=(X_train.shape[1],), activation='softmax')  # Adjust input_shape to match the number of features in X
    ])

    # Compile the model
    optimizer = Adam(learning_rate)
    custom_loss = CustomLossLL2(lambda_t, model, mu, sigma_TL)

    # Custom training loop
    epochs = num_tier
    lowest_loss = float('inf')
    best_weights = None
    best_model = None

    for epoch in range(epochs):
        loss = custom_train_step(model, optimizer, X_train, y_one_hot, custom_loss)
        loss_value = loss.numpy()
        if epoch % 20 == 0:
            print(f"Epoch {epoch}, Loss: {loss.numpy()}")

        if (abs(loss_value) < lowest_loss):
            lowest_loss = abs(loss_value)
            best_model = model
            best_weights = [layer.get_weights() for layer in model.layers]

    return best_model, best_weights, lowest_loss

def GetConfusionMatrix(model, X_train, X_test, y_train, y_test):
    y_pred_prob = model.predict(X_train)
    y_pred = np.argmax(y_pred_prob, axis=1)

    print("Classification TRAIN DATA \n=======================")
    print(classification_report(y_true= y_train, y_pred=y_pred))
    print("Confusion matrix \n=======================")
    print(confusion_matrix(y_true= y_train, y_pred=y_pred))

    y_pred_prob = model.predict(X_test)
    y_pred = np.argmax(y_pred_prob, axis=1)
    print("Classification TEST DATA \n=======================")
    print(classification_report(y_true=y_test, y_pred=y_pred))
    print("Confusion matrix \n=======================")
    print(confusion_matrix(y_true=y_test, y_pred=y_pred))


def tgt_test_wLTL(data, target_subjects ,condition):
        tgt_data = target_subjects + "_test"

        if condition == "noEA":
            X = data[target_subjects]['Raw_csp']
            y = data[target_subjects]['Raw_csp_label']
            X_test = data[tgt_data]['Raw_csp']
            y_test = data[tgt_data]['Raw_csp_label']
            store_ws = 'wt_Raw'

        else:
            X = data[target_subjects]['EA_csp']
            y = data[target_subjects]['EA_csp_label']
            X_test = data[tgt_data]['EA_csp']
            y_test = data[tgt_data]['EA_csp_label']
            store_ws = 'wt_EA'

        mu = data[target_subjects]['mu_ws']
        sigma_TL = data[target_subjects]['Sigma_TL']

        X_train = X
        y_train = y
        
        model, weights, loss = train_weight_LL2(X_train=X_train, y_train=y_train, mu =mu, sigma_TL=sigma_TL, lambd= 0.1, num_tier=5000, learning_rate= 0.01)
        print("weights of ", str(target_subjects), ": ", weights)
        print("loss of ", str(target_subjects), ": ", loss)
        data[target_subjects][store_ws] = weights

        GetConfusionMatrix(model, X_train, X_test, y_train, y_test)

tgt_test_wLTL(CSP2D_Epoch, target_subjects= target_data_0 ,condition = condition_wLTL)

Epoch 0, Loss: 167.94546508789062
Epoch 20, Loss: 74.54048156738281
Epoch 40, Loss: 65.38609313964844
Epoch 60, Loss: 73.4594955444336
Epoch 80, Loss: 69.89293670654297
Epoch 100, Loss: 70.67900848388672
Epoch 120, Loss: 69.95596313476562
Epoch 140, Loss: 69.67904663085938
Epoch 160, Loss: 69.31879425048828
Epoch 180, Loss: 68.95577239990234
Epoch 200, Loss: 68.6186752319336
Epoch 220, Loss: 68.28462219238281
Epoch 240, Loss: 67.96063232421875
Epoch 260, Loss: 67.64575958251953
Epoch 280, Loss: 67.3375015258789
Epoch 300, Loss: 67.03396606445312
Epoch 320, Loss: 66.73272705078125
Epoch 340, Loss: 66.43109893798828
Epoch 360, Loss: 66.1263427734375
Epoch 380, Loss: 65.8156967163086
Epoch 400, Loss: 65.49644470214844
Epoch 420, Loss: 65.16597747802734
Epoch 440, Loss: 64.82178497314453
Epoch 460, Loss: 64.46162414550781
Epoch 480, Loss: 64.0833511352539
Epoch 500, Loss: 63.68505096435547
Epoch 520, Loss: 63.26509475708008
Epoch 540, Loss: 62.82196807861328
Epoch 560, Loss: 62.35437011718