In [1]:
import glob
import os
import pickle
import sys
from typing import List

import numpy as np
import torch

In [2]:
# Data locations, relative to the notebook's working directory.
# participant_data_path holds entries whose names start with a 3-character participant ID.
participant_data_path, processed_data_path, split_data_path = (
    './EEG/number',
    './EEG/processed_data',
    './EEG/split_data',
)

In [3]:
# A participant ID is the first three characters of each entry in participant_data_path.
# Several entries can share an ID, so collect the IDs in a set before sorting them.
ps = os.listdir(participant_data_path)
participants = sorted({entry[:3] for entry in ps})

In [4]:
# Load the per-user-fold split produced by EEG-Preprocess.ipynb / EEG-Split.ipynb.
load_name = "userfold_data_scaled_p_dictionary-number"
data_dir = "./EEG/split_data/standard_scaled"
pickle_path = os.path.join(data_dir, f"{load_name}.pkl")
try:
    # NOTE: pickle.load can execute arbitrary code; only load files this project produced.
    with open(pickle_path, "rb") as pickle_file:
        raw_user_fold = pickle.load(pickle_file)
except FileNotFoundError:
    # Only a missing file means "run the upstream notebooks first". Other failures
    # (corrupt pickle, ImportError for an unpickled class, permission errors) now surface
    # with their real traceback instead of being misreported as a missing file.
    print("pickle file does not exist. Use EEG-Preprocess.ipynb and EEG-Split.ipynb to save data setting.")
    sys.exit()

In [None]:
from utilities.userfold_framework import *
from Models.AR_EEG_models import *
import Models.model_func as Model_Func
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List
from torcheeg.models import EEGNet
from torch import nn


DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ---------------------------------------------------------------------------
# Training configuration
# ---------------------------------------------------------------------------
learning_rate = 0.00005
batch_size = 64
n_epochs = 300
transpose_channels = True
# Destination for each participant's model weights (.pt) and importance weights (.pkl).
# Hoisted out of the loop because it is the same for every participant.
saved_dir = "./EEG/saved_models/Userfold/run3"
train_func = eeg_train

# ---------------------------------------------------------------------------
# Per-participant results, index-aligned with `participants`
# ---------------------------------------------------------------------------
participants_dictionary = []
participants_grads_dictionary = {}  # NOTE(review): not populated in this cell
b_acc_list = []
c0_acc_list = []
c1_acc_list = []

# torch.save() and open(..., "wb") fail if the target directory does not exist yet.
os.makedirs(saved_dir, exist_ok=True)

for i in range(len(participants)):
    # Dataloaders for fold i (presumably participants[i] is the held-out validation user;
    # confirm in user_fold_load).
    train_dataloader, val_dataloader, classes, input_dim, class_ratio = user_fold_load(
        i,
        raw_user_fold,
        participants,
        batch_size=batch_size,
        transpose_channels=transpose_channels,
    )

    # input_dim is consumed as (electrodes, timesteps), per EEGNet's
    # num_electrodes / chunk_size arguments below.
    classifier = EEGNet(
        chunk_size=input_dim[1],
        num_electrodes=input_dim[0],
        num_classes=classes,
        kernel_1=32,
        kernel_2=32,
        F1=8,
        F2=16,
        dropout=0.5,
    ).to(DEVICE)

    # Per-class loss weights taken from class_ratio (presumably to counter class
    # imbalance; confirm what user_fold_load returns).
    criterion = torch.nn.CrossEntropyLoss(
        weight=torch.tensor(class_ratio, dtype=torch.float).to(DEVICE)
    )

    model = EEGNet_IE_HP_Wrapper(DEVICE, classifier, input_dim).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # NOTE(review): the scheduler is constructed but not passed to training_procedure,
    # so it currently has no effect on training.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2)

    model.training_procedure(
        iteration=n_epochs,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        print_cycle=2,
        path="./dictionary/intermdiate_dicts",
        loss_func=criterion,
        optimiser=optimizer,  # scheduler=scheduler,
        train_func=train_func,
    )

    # model.epoch == n_epochs + 1 is treated as "ran to completion"; any other value is used
    # as-is (presumably the epoch training stopped at; confirm in training_procedure).
    # EPOCH tags the saved file names.
    EPOCH = n_epochs if model.epoch == n_epochs + 1 else model.epoch

    torch.save(
        model.state_dict(),
        os.path.join(saved_dir, f"Userfold-{participants[i]}-EEGNet-Weight_Multivariate-e{EPOCH}.pt"),
    )
    # Use a context manager so the file is flushed and closed deterministically.
    with open(f"{saved_dir}/Userfold-{participants[i]}-EEGNet-Weight_Multivariate-w-e{EPOCH}.pkl", "wb") as weights_file:
        pickle.dump(model.return_IE_weights(), weights_file)

    # To evaluate previously saved weights instead of retraining, replace the
    # training/saving steps above with:
    #     with open(os.path.join(saved_dir, f"Userfold-{participants[i]}-EEGNet-Weight_Multivariate-e{n_epochs}.pt"), "rb") as f:
    #         model.load_state_dict(torch.load(f))

    prediction, dictionary = model.prediction_procedure(val_dataloader, dict_flag=True)

    # Ground-truth labels in dataloader order. NOTE(review): these only line up with
    # `prediction` if val_dataloader iterates in a fixed order (no shuffling); confirm.
    ys = np.concatenate([y.detach().cpu().numpy() for _, y in val_dataloader])

    c0_acc, c1_acc, b_acc = calculate_accuracy(ys, prediction)
    print("c0_acc", c0_acc, ", c1_acc", c1_acc, ", b_acc", b_acc)
    b_acc_list.append(b_acc)
    c0_acc_list.append(c0_acc)
    c1_acc_list.append(c1_acc)
    participants_dictionary.append(dictionary)

# Per-participant report. Each `dictionary` is read via the 'weighted avg' / 'accuracy' keys
# (a classification_report-style dict, presumably).
weighted_f1_list = []
for i, dictionary in enumerate(participants_dictionary):
    print(f"User {participants[i]} f1: {dictionary['weighted avg']['f1-score']} acc: {dictionary['accuracy']}")
    print(f" c0: {c0_acc_list[i]} c1: {c1_acc_list[i]} bacc: {b_acc_list[i]}")
    weighted_f1_list.append(dictionary['weighted avg']['f1-score'])

print(f"average {np.mean(weighted_f1_list)}")
print()
# Means across participants of b_acc, c1_acc and c0_acc (as returned by calculate_accuracy).
print(np.array(b_acc_list).mean())
print(np.array(c1_acc_list).mean())
print(np.array(c0_acc_list).mean())


Iterations:   0%|                                           | 0/300 [00:00<?, ?it/s]

Epoch  0 , loss 0.7025379678782295


Iterations:   0%|                                   | 1/300 [00:03<16:54,  3.39s/it]

Epoch:  0
t_loss:  0.7025379678782295 , v_loss:  0.6931550900141398
t_acc:  0.5188297541238718 , v_acc:  0.4720496894409938
t_recall:  0.5109468714401394 , v_recall:  0.5951351351351352
t_prec:  0.509293582115059 , v_prec:  0.6222788327929597
t_f:  0.4943464257211265 , v_f:  0.46678355737385546
////////


Iterations:   1%|▏                                  | 2/300 [00:04<09:04,  1.83s/it]

Epoch  1 , loss 0.6706051078497195
Epoch  2 , loss 0.6459513575422997


Iterations:   1%|▎                                  | 3/300 [00:05<07:50,  1.58s/it]

Epoch:  2
t_loss:  0.6459513575422997 , v_loss:  0.6896252433458964
t_acc:  0.5969498910675382 , v_acc:  0.4782608695652174
t_recall:  0.5251190623241515 , v_recall:  0.5611711711711712
t_prec:  0.5250189102242798 , v_prec:  0.5618622448979591
t_f:  0.5250659495079769 , v_f:  0.47824074074074074
////////


Iterations:   1%|▍                                  | 4/300 [00:06<06:18,  1.28s/it]

Epoch  3 , loss 0.6251586535397697
Epoch  4 , loss 0.606989299549776


Iterations:   2%|▌                                  | 5/300 [00:07<06:03,  1.23s/it]

Epoch:  4
t_loss:  0.606989299549776 , v_loss:  0.6717332402865092
t_acc:  0.6314970432617492 , v_acc:  0.6708074534161491
t_recall:  0.5177474917994538 , v_recall:  0.5963963963963964
t_prec:  0.5232858128995468 , v_prec:  0.6054394954670871
t_f:  0.5129788927204517 , v_f:  0.5993332394233929
////////


Iterations:   2%|▋                                  | 6/300 [00:08<05:06,  1.04s/it]

Epoch  5 , loss 0.5881787059353847
Epoch  6 , loss 0.5667766268346824


Iterations:   2%|▊                                  | 7/300 [00:09<05:23,  1.10s/it]

Epoch:  6
t_loss:  0.5667766268346824 , v_loss:  0.6506198843320211
t_acc:  0.6676003734827264 , v_acc:  0.6956521739130435
t_recall:  0.5221332400049409 , v_recall:  0.5264864864864865
t_prec:  0.5445345638918346 , v_prec:  0.6363636363636364
t_f:  0.5024792784516776 , v_f:  0.477722608407812
////////


Iterations:   3%|▉                                  | 8/300 [00:09<04:44,  1.03it/s]

Epoch  7 , loss 0.5508239497156704
Epoch  8 , loss 0.5360890130201975


Iterations:   3%|█                                  | 9/300 [00:11<04:59,  1.03s/it]

Epoch:  8
t_loss:  0.5360890130201975 , v_loss:  0.6380334198474884
t_acc:  0.6878306878306878 , v_acc:  0.7018633540372671
t_recall:  0.5254608089375661 , v_recall:  0.52
t_prec:  0.5777003065960026 , v_prec:  0.8490566037735849
t_f:  0.4922845547886018 , v_f:  0.44957264957264953
////////


Iterations:   3%|█▏                                | 10/300 [00:12<04:44,  1.02it/s]

Epoch  9 , loss 0.5296928321613985
Epoch  10 , loss 0.5166183914624008


Iterations:   4%|█▏                                | 11/300 [00:13<04:58,  1.03s/it]

Epoch:  10
t_loss:  0.5166183914624008 , v_loss:  0.6333273549874624
t_acc:  0.6862745098039216 , v_acc:  0.6956521739130435
t_recall:  0.5082403480600046 , v_recall:  0.51
t_prec:  0.5436563528573437 , v_prec:  0.846875
t_f:  0.45191878112113243 , v_f:  0.42920193907821425
////////


Iterations:   4%|█▎                                | 12/300 [00:13<04:25,  1.09it/s]

Epoch  11 , loss 0.5115708586047677
Epoch  12 , loss 0.49491902484613304


Iterations:   4%|█▍                                | 13/300 [00:15<04:53,  1.02s/it]

Epoch:  12
t_loss:  0.49491902484613304 , v_loss:  0.6317848165829977
t_acc:  0.6931216931216931 , v_acc:  0.6894409937888198
t_recall:  0.5085615075280328 , v_recall:  0.5
t_prec:  0.5746648579636131 , v_prec:  0.3447204968944099
t_f:  0.4421859285156877 , v_f:  0.4080882352941176
////////


Iterations:   5%|█▌                                | 14/300 [00:15<04:32,  1.05it/s]

Epoch  13 , loss 0.48735906505117227
Epoch  14 , loss 0.4811030836666332


Iterations:   5%|█▋                                | 15/300 [00:17<04:48,  1.01s/it]

Epoch:  14
t_loss:  0.4811030836666332 , v_loss:  0.6314314206441244
t_acc:  0.6909430438842203 , v_acc:  0.6894409937888198
t_recall:  0.503832640232772 , v_recall:  0.5
t_prec:  0.5409417456749096 , v_prec:  0.3447204968944099
t_f:  0.43173450597019947 , v_f:  0.4080882352941176
////////


Iterations:   5%|█▊                                | 16/300 [00:17<04:20,  1.09it/s]

Epoch  15 , loss 0.47612931564742444
Epoch  16 , loss 0.4673481735528684


Iterations:   6%|█▉                                | 17/300 [00:18<04:48,  1.02s/it]

Epoch:  16
t_loss:  0.4673481735528684 , v_loss:  0.6315363446871439
t_acc:  0.6959228135698724 , v_acc:  0.6894409937888198
t_recall:  0.5042491868077572 , v_recall:  0.5
t_prec:  0.6079723791588199 , v_prec:  0.3447204968944099
t_f:  0.4238178228943825 , v_f:  0.4080882352941176
////////


Iterations:   6%|██                                | 18/300 [00:19<04:18,  1.09it/s]

Epoch  17 , loss 0.4682005264011084
Epoch  18 , loss 0.466853213076498


Iterations:   6%|██▏                               | 19/300 [00:20<04:36,  1.02it/s]

Epoch:  18
t_loss:  0.466853213076498 , v_loss:  0.6332898040612539
t_acc:  0.6949891067538126 , v_acc:  0.6894409937888198
t_recall:  0.5035780458681599 , v_recall:  0.5
t_prec:  0.5819038642789821 , v_prec:  0.3447204968944099
t_f:  0.4234496124031008 , v_f:  0.4080882352941176
////////


Iterations:   7%|██▎                               | 20/300 [00:21<04:12,  1.11it/s]

Epoch  19 , loss 0.4585847638401331
Epoch  20 , loss 0.4563316054203931


Iterations:   7%|██▍                               | 21/300 [00:22<04:30,  1.03it/s]

Epoch:  20
t_loss:  0.4563316054203931 , v_loss:  0.6345209379990896
t_acc:  0.6940553999377529 , v_acc:  0.6894409937888198
t_recall:  0.5017567697396412 , v_recall:  0.5
t_prec:  0.5481806775407779 , v_prec:  0.3447204968944099
t_f:  0.41933759848979707 , v_f:  0.4080882352941176
////////


Iterations:   7%|██▍                               | 22/300 [00:23<04:20,  1.07it/s]

Epoch  21 , loss 0.4523195863938799
Epoch  22 , loss 0.4507758909580754


Iterations:   8%|██▌                               | 23/300 [00:24<04:40,  1.01s/it]

Epoch:  22
t_loss:  0.4507758909580754 , v_loss:  0.6348722477753957
t_acc:  0.6984126984126984 , v_acc:  0.6894409937888198
t_recall:  0.5074765649661684 , v_recall:  0.5
t_prec:  0.67698947322821 , v_prec:  0.3447204968944099
t_f:  0.42946028474236353 , v_f:  0.4080882352941176
////////


Iterations:   8%|██▋                               | 24/300 [00:25<04:14,  1.08it/s]

Epoch  23 , loss 0.4502906571416294
Epoch  24 , loss 0.44361085050246296


Iterations:   8%|██▊                               | 25/300 [00:26<04:35,  1.00s/it]

Epoch:  24
t_loss:  0.44361085050246296 , v_loss:  0.6344233155250549
t_acc:  0.6965452847805789 , v_acc:  0.6894409937888198
t_recall:  0.5038340127091311 , v_recall:  0.5
t_prec:  0.6380961012424612 , v_prec:  0.3447204968944099
t_f:  0.42123755684071923 , v_f:  0.4080882352941176
////////


Iterations:   9%|██▉                               | 26/300 [00:27<04:09,  1.10it/s]

Epoch  25 , loss 0.4461994223734912
Epoch  26 , loss 0.44604378179007886


Iterations:   9%|███                               | 27/300 [00:28<04:38,  1.02s/it]

Epoch:  26
t_loss:  0.44604378179007886 , v_loss:  0.6361049513022105
t_acc:  0.6949891067538126 , v_acc:  0.6894409937888198
t_recall:  0.5012777754903172 , v_recall:  0.5
t_prec:  0.562363238512035 , v_prec:  0.3447204968944099
t_f:  0.4158754021869472 , v_f:  0.4080882352941176
////////


Iterations:   9%|███▏                              | 28/300 [00:29<04:18,  1.05it/s]

Epoch  27 , loss 0.4488006970461677
Epoch  28 , loss 0.4464272582063488


Iterations:  10%|███▎                              | 29/300 [00:30<04:33,  1.01s/it]

Epoch:  28
t_loss:  0.4464272582063488 , v_loss:  0.6326283911863962
t_acc:  0.6956115779645191 , v_acc:  0.6894409937888198
t_recall:  0.5020127365806124 , v_recall:  0.5
t_prec:  0.5982338230697093 , v_prec:  0.3447204968944099
t_f:  0.4170674932028922 , v_f:  0.4080882352941176
////////


Iterations:  10%|███▍                              | 30/300 [00:31<04:05,  1.10it/s]

Epoch  29 , loss 0.44517050888024123
Epoch  30 , loss 0.45177117808192385


Iterations:  10%|███▌                              | 31/300 [00:32<04:27,  1.01it/s]

Epoch:  30
t_loss:  0.45177117808192385 , v_loss:  0.6294560382763544
t_acc:  0.6965452847805789 , v_acc:  0.6894409937888198
t_recall:  0.5035464789119007 , v_recall:  0.5
t_prec:  0.6426783479349187 , v_prec:  0.3447204968944099
t_f:  0.4202877765739024 , v_f:  0.4080882352941176
////////


Iterations:  11%|███▋                              | 32/300 [00:33<04:08,  1.08it/s]

Epoch  31 , loss 0.43753409035065594
Epoch  32 , loss 0.4383920933686051


Iterations:  11%|███▋                              | 33/300 [00:34<04:26,  1.00it/s]

Epoch:  32
t_loss:  0.4383920933686051 , v_loss:  0.6254940380652746
t_acc:  0.6968565203859322 , v_acc:  0.6956521739130435
t_recall:  0.5034826587612028 , v_recall:  0.51
t_prec:  0.6699749921850578 , v_prec:  0.846875
t_f:  0.4194516752347822 , v_f:  0.42920193907821425
////////


Iterations:  11%|███▊                              | 34/300 [00:34<03:59,  1.11it/s]

Epoch  33 , loss 0.4359377914784001
Epoch  34 , loss 0.44350121652378754


Iterations:  12%|███▉                              | 35/300 [00:36<04:27,  1.01s/it]

Epoch:  34
t_loss:  0.44350121652378754 , v_loss:  0.6207669973373413
t_acc:  0.6962340491752257 , v_acc:  0.6956521739130435
t_recall:  0.5033227652653683 , v_recall:  0.51
t_prec:  0.6262910798122066 , v_prec:  0.846875
t_f:  0.4201690740864009 , v_f:  0.42920193907821425
////////


Iterations:  12%|████                              | 36/300 [00:37<04:08,  1.06it/s]

Epoch  35 , loss 0.4410121908374861
Epoch  36 , loss 0.43810248842426375


Iterations:  12%|████▏                             | 37/300 [00:38<04:23,  1.00s/it]

Epoch:  36
t_loss:  0.43810248842426375 , v_loss:  0.6131264766057333
t_acc:  0.6993464052287581 , v_acc:  0.6956521739130435
t_recall:  0.5075726383113051 , v_recall:  0.51
t_prec:  0.7357838808011168 , v_prec:  0.846875
t_f:  0.4279841503870254 , v_f:  0.42920193907821425
////////


Iterations:  13%|████▎                             | 38/300 [00:38<03:59,  1.09it/s]

Epoch  37 , loss 0.4374657875182582
Epoch  38 , loss 0.43251505377245886


Iterations:  13%|████▍                             | 39/300 [00:40<04:20,  1.00it/s]

Epoch:  38
t_loss:  0.43251505377245886 , v_loss:  0.610899622241656
t_acc:  0.6968565203859322 , v_acc:  0.6956521739130435
t_recall:  0.5054953953418152 , v_recall:  0.51
t_prec:  0.6346938775510205 , v_prec:  0.846875
t_f:  0.42605217404062706 , v_f:  0.42920193907821425
////////


Iterations:  13%|████▌                             | 40/300 [00:40<03:58,  1.09it/s]

Epoch  39 , loss 0.4319348446294373
Epoch  40 , loss 0.43687473734219867


Iterations:  14%|████▋                             | 41/300 [00:42<04:23,  1.02s/it]

Epoch:  40
t_loss:  0.43687473734219867 , v_loss:  0.6084940532843272
t_acc:  0.6971677559912854 , v_acc:  0.6956521739130435
t_recall:  0.5054315751911174 , v_recall:  0.51
t_prec:  0.6489648682559599 , v_prec:  0.846875
t_f:  0.4252446422488022 , v_f:  0.42920193907821425
////////


Iterations:  14%|████▊                             | 42/300 [00:42<03:56,  1.09it/s]

Epoch  41 , loss 0.4365409186073378
Epoch  42 , loss 0.4367132379728205


Iterations:  14%|████▊                             | 43/300 [00:43<04:14,  1.01it/s]

Epoch:  42
t_loss:  0.4367132379728205 , v_loss:  0.6057532727718353
t_acc:  0.699035169623405 , v_acc:  0.6956521739130435
t_recall:  0.508499059853694 , v_recall:  0.51
t_prec:  0.688332556112001 , v_prec:  0.846875
t_f:  0.431556640045754 , v_f:  0.42920193907821425
////////


Iterations:  15%|████▉                             | 44/300 [00:44<03:52,  1.10it/s]

Epoch  43 , loss 0.4317142414111717
Epoch  44 , loss 0.4318320464854147


Iterations:  15%|█████                             | 45/300 [00:45<04:13,  1.01it/s]

Epoch:  44
t_loss:  0.4318320464854147 , v_loss:  0.602776030699412
t_acc:  0.6959228135698724 , v_acc:  0.6956521739130435
t_recall:  0.5033865854160662 , v_recall:  0.51
t_prec:  0.6104323308270676 , v_prec:  0.846875
t_f:  0.4209981125404011 , v_f:  0.42920193907821425
////////


Iterations:  15%|█████▏                            | 46/300 [00:46<04:06,  1.03it/s]

Epoch  45 , loss 0.4327730559835247
Epoch  46 , loss 0.4373088403075349


Iterations:  16%|█████▎                            | 47/300 [00:47<04:26,  1.05s/it]

Epoch:  46
t_loss:  0.4373088403075349 , v_loss:  0.5984770258267721
t_acc:  0.6959228135698724 , v_acc:  0.7018633540372671
t_recall:  0.5042491868077572 , v_recall:  0.52
t_prec:  0.6079723791588199 , v_prec:  0.8490566037735849
t_f:  0.4238178228943825 , v_f:  0.44957264957264953
////////


Iterations:  16%|█████▍                            | 48/300 [00:48<04:11,  1.00it/s]

Epoch  47 , loss 0.4329748889979194
Epoch  48 , loss 0.4292204841679218


Iterations:  16%|█████▌                            | 49/300 [00:49<04:19,  1.03s/it]

Epoch:  48
t_loss:  0.4292204841679218 , v_loss:  0.5962955405314764
t_acc:  0.7005913476501712 , v_acc:  0.7018633540372671
t_recall:  0.5107677632752776 , v_recall:  0.52
t_prec:  0.7177571563384713 , v_prec:  0.8490566037735849
t_f:  0.4358598188563978 , v_f:  0.44957264957264953
////////


Iterations:  17%|█████▋                            | 50/300 [00:50<03:55,  1.06it/s]

Epoch  49 , loss 0.4286900431502099
Epoch  50 , loss 0.43208782462512746


Iterations:  17%|█████▊                            | 51/300 [00:51<04:19,  1.04s/it]

Epoch:  50
t_loss:  0.43208782462512746 , v_loss:  0.5915191223224004
t_acc:  0.6962340491752257 , v_acc:  0.7080745341614907
t_recall:  0.50476043425152 , v_recall:  0.53
t_prec:  0.6166797488226059 , v_prec:  0.8512658227848101
t_f:  0.4248736364103204 , v_f:  0.4692431787893666
////////


Iterations:  17%|█████▉                            | 52/300 [00:52<03:58,  1.04it/s]

Epoch  51 , loss 0.4285354246111477
Epoch  52 , loss 0.42549822961582856


Iterations:  18%|██████                            | 53/300 [00:53<04:11,  1.02s/it]

Epoch:  52
t_loss:  0.42549822961582856 , v_loss:  0.5882058441638947
t_acc:  0.6984126984126984 , v_acc:  0.7080745341614907
t_recall:  0.5092017677495505 , v_recall:  0.53
t_prec:  0.6546573678221019 , v_prec:  0.8512658227848101
t_f:  0.43492779051180885 , v_f:  0.4692431787893666
////////


Iterations:  18%|██████                            | 54/300 [00:54<03:47,  1.08it/s]

Epoch  53 , loss 0.4335464916977228
Epoch  54 , loss 0.4314744776370479


Iterations:  18%|██████▏                           | 55/300 [00:55<04:19,  1.06s/it]

Epoch:  54
t_loss:  0.4314744776370479 , v_loss:  0.586230273048083
t_acc:  0.6999688764394647 , v_acc:  0.7080745341614907
t_recall:  0.5083075994016003 , v_recall:  0.53
t_prec:  0.7586678442209625 , v_prec:  0.8512658227848101
t_f:  0.42916844821231115 , v_f:  0.4692431787893666
////////


Iterations:  19%|██████▎                           | 56/300 [00:56<03:52,  1.05it/s]

Epoch  55 , loss 0.4306218594897027
Epoch  56 , loss 0.4279610777602476


Iterations:  19%|██████▍                           | 57/300 [00:57<04:06,  1.02s/it]

Epoch:  56
t_loss:  0.4279610777602476 , v_loss:  0.5817717015743256
t_acc:  0.7009025832555245 , v_acc:  0.7080745341614907
t_recall:  0.5121416121107314 , v_recall:  0.53
t_prec:  0.7040683726509396 , v_prec:  0.8512658227848101
t_f:  0.4395929893517526 , v_f:  0.4692431787893666
////////


Iterations:  19%|██████▌                           | 58/300 [00:58<03:46,  1.07it/s]

Epoch  57 , loss 0.429557451430489
Epoch  58 , loss 0.42403709537842693


Iterations:  20%|██████▋                           | 59/300 [00:59<03:58,  1.01it/s]

Epoch:  58
t_loss:  0.42403709537842693 , v_loss:  0.5774204581975937
t_acc:  0.6984126984126984 , v_acc:  0.7080745341614907
t_recall:  0.5092017677495505 , v_recall:  0.53
t_prec:  0.6546573678221019 , v_prec:  0.8512658227848101
t_f:  0.43492779051180885 , v_f:  0.4692431787893666
////////


Iterations:  20%|██████▊                           | 60/300 [01:00<03:39,  1.09it/s]

Epoch  59 , loss 0.42752207669557307
Epoch  60 , loss 0.42739493122287825


Iterations:  20%|██████▉                           | 61/300 [01:01<04:02,  1.01s/it]

Epoch:  60
t_loss:  0.42739493122287825 , v_loss:  0.5776236255963644
t_acc:  0.6977902272019919 , v_acc:  0.7080745341614907
t_recall:  0.5101920094426373 , v_recall:  0.53
t_prec:  0.6330188679245283 , v_prec:  0.8512658227848101
t_f:  0.43910723098317317 , v_f:  0.4692431787893666
////////


Iterations:  21%|███████                           | 62/300 [01:02<03:44,  1.06it/s]

Epoch  61 , loss 0.4254034179098466
Epoch  62 , loss 0.42570928615682263


Iterations:  21%|███████▏                          | 63/300 [01:03<04:00,  1.01s/it]

Epoch:  62
t_loss:  0.42570928615682263 , v_loss:  0.5760328223307928
t_acc:  0.7018362900715842 , v_acc:  0.7080745341614907
t_recall:  0.5133878206447894 , v_recall:  0.53
t_prec:  0.7197256385998108 , v_prec:  0.8512658227848101
t_f:  0.44178690344062155 , v_f:  0.4692431787893666
////////


Iterations:  21%|███████▎                          | 64/300 [01:04<03:35,  1.10it/s]

Epoch  63 , loss 0.4249820113182068
Epoch  64 , loss 0.4215603868166606


Iterations:  22%|███████▎                          | 65/300 [01:05<04:08,  1.06s/it]

Epoch:  64
t_loss:  0.4215603868166606 , v_loss:  0.5747045427560806
t_acc:  0.6984126984126984 , v_acc:  0.7080745341614907
t_recall:  0.5092017677495505 , v_recall:  0.53
t_prec:  0.6546573678221019 , v_prec:  0.8512658227848101
t_f:  0.43492779051180885 , v_f:  0.4692431787893666
////////


Iterations:  22%|███████▍                          | 66/300 [01:06<03:47,  1.03it/s]

Epoch  65 , loss 0.42458982152097363
Epoch  66 , loss 0.4216153095750248


Iterations:  22%|███████▌                          | 67/300 [01:07<03:57,  1.02s/it]

Epoch:  66
t_loss:  0.4216153095750248 , v_loss:  0.572292427221934
t_acc:  0.7049486461251168 , v_acc:  0.7018633540372671
t_recall:  0.5187878288796476 , v_recall:  0.5254954954954955
t_prec:  0.7407407407407407 , v_prec:  0.7253184713375795
t_f:  0.45282225728846004 , v_f:  0.4660033167495854
////////


Iterations:  23%|███████▋                          | 68/300 [01:08<03:34,  1.08it/s]

Epoch  67 , loss 0.42378677573858525
Epoch  68 , loss 0.423418847953572


Iterations:  23%|███████▊                          | 69/300 [01:09<03:51,  1.00s/it]

Epoch:  68
t_loss:  0.423418847953572 , v_loss:  0.5695409129063288
t_acc:  0.7005913476501712 , v_acc:  0.7080745341614907
t_recall:  0.5130680336531203 , v_recall:  0.5354954954954955
t_prec:  0.6806165033196332 , v_prec:  0.7525641025641026
t_f:  0.44299409062010286 , v_f:  0.4847122914538645
////////


Iterations:  23%|███████▉                          | 70/300 [01:10<03:41,  1.04it/s]

Epoch  69 , loss 0.42781779286908167
Epoch  70 , loss 0.42507968638457505


Iterations:  24%|████████                          | 71/300 [01:11<03:59,  1.05s/it]

Epoch:  70
t_loss:  0.42507968638457505 , v_loss:  0.5692323644955953
t_acc:  0.7009025832555245 , v_acc:  0.7080745341614907
t_recall:  0.5121416121107314 , v_recall:  0.5354954954954955
t_prec:  0.7040683726509396 , v_prec:  0.7525641025641026
t_f:  0.4395929893517526 , v_f:  0.4847122914538645
////////


In [None]:
# Summaries of the per-participant reports collected in the training loop. Both helpers are
# project functions brought in by the star imports of the training cell (presumably
# utilities.userfold_framework), which defines their output format. The second call is the
# cell's last expression, so Jupyter displays its return value if it is not None.
userfold_results_summary(participants_dictionary, participants)
userfold_classwise_results_summary(participants_dictionary, participants)


In [None]:
# Time (in seconds) for each timestep column, used as heatmap x-tick labels.
# The window is assumed to span -0.5 s .. +1.0 s, matching the tick labels of the line-plot
# cell. TODO confirm against the preprocessing notebook.
start = -0.5
window_length = 1.5
n_timesteps = input_dim[1]
# Derive the spacing from the actual number of timesteps instead of a hard-coded 188, so
# the last label always lands on start + window_length. max() guards a 1-sample input.
step = window_length / max(n_timesteps - 1, 1)
timestep_labels = [round(start + step * t, 3) for t in range(n_timesteps)]

In [None]:
# Electrode labels for the 32 channel rows of the importance maps. The order is assumed to
# match the data's channel axis (TODO confirm against the preprocessing notebook).
channel_names = (
    "AFz F3 F1 Fz F2 F4 "
    "FC5 FC3 FC1 FCz FC2 FC4 FC6 "
    "C5 C3 C1 Cz C2 C4 C6 "
    "CP5 CP3 CP1 CPz CP2 CP4 CP6 "
    "P3 P1 Pz P2 P4"
).split()

In [None]:
from sklearn.preprocessing import MinMaxScaler

def _weight_file_for(participant):
    """Return the path of the saved importance-weight pickle for ``participant``.

    The training cell tags each participant's file with that participant's own EPOCH, but
    the global ``EPOCH`` only holds the value from the *last* participant trained. This
    prefers the ``EPOCH``-tagged file (the previous behaviour) and otherwise falls back to
    the single file for this participant with any epoch tag.

    Raises:
        FileNotFoundError: if no candidate file exists, or more than one does.
    """
    prefix = f"{saved_dir}/Userfold-{participant}-EEGNet-Weight_Multivariate-w-e"
    exact_path = f"{prefix}{EPOCH}.pkl"
    if os.path.exists(exact_path):
        return exact_path
    candidates = sorted(glob.glob(f"{glob.escape(prefix)}*.pkl"))
    if len(candidates) != 1:
        raise FileNotFoundError(
            f"expected exactly one weight file matching {prefix}*.pkl for participant "
            f"{participant!r}, found {len(candidates)}: {candidates}"
        )
    return candidates[0]


# Load every participant's importance weights and average them element-wise.
participants_w_list = []
for participant in participants:
    with open(_weight_file_for(participant), "rb") as weights_file:
        participants_w_list.append(pickle.load(weights_file))

avg_w = np.array(participants_w_list).mean(axis=0)

In [None]:
# Heatmap of the participant-averaged importance weights. Rows are labelled with
# channel_names and columns with timestep_labels.
fig, ax = plt.subplots()
sns.heatmap(
    pd.DataFrame(avg_w),
    ax=ax,
    xticklabels=timestep_labels,
    yticklabels=channel_names,
    annot=False,
    cbar_kws={"pad": 0.02},
)

# A very large canvas, so every channel/timestep tick label stays legible.
fig.set_figwidth(40)
fig.set_figheight(20)
plt.tight_layout()

In [None]:
import matplotlib.pyplot as plt
import matplotlib
from sklearn.preprocessing import MinMaxScaler

# Use an 18 pt base font for this figure. Text objects read font.size from rcParams when they
# are created, so this must happen *before* plotting. The previous version updated it after
# plotting, which left this figure's text at the old size on a fresh kernel and only took
# effect on re-runs. NOTE: rcParams is global, so later figures also use 18 pt.
matplotlib.rcParams.update({"font.size": 18})

fig, ax = plt.subplots(figsize=(8, 6))

# One importance value per timestep: sum over axis 0, then min-max scale to [0, 1].
# Assumes avg_w is (channels, timesteps), as labelled in the heatmap.
scaler = MinMaxScaler()
scaled_avg_w = scaler.fit_transform(avg_w.sum(0).reshape(-1, 1))
df = pd.DataFrame(scaled_avg_w)

ax.plot(df)
# NOTE(review): the tick labels below are in seconds while this axis label says "Timesteps".
ax.set_xlabel("Timesteps")
ax.set_ylabel("Normalised Importance")

# Highlighted timestep ranges. NOTE(review): what these windows mark is not recorded here;
# document their source.
ax.axvspan(107, 121, color="green", alpha=0.2)
for span_start, span_end in [(112, 128), (142, 150), (157, 159)]:
    ax.axvspan(span_start, span_end, color="red", alpha=0.2)

# Tick positions assume 188 timesteps mapped onto -0.5 s .. 1.0 s.
ax.set_xticks([0, 31, 62, 93, 124, 155, 187])
ax.set_xticklabels([-0.5, -0.25, 0, 0.25, 0.5, 0.75, 1.0])
ax.margins(x=0)
fig.tight_layout()

In [None]:
from sklearn.preprocessing import MinMaxScaler
import mne
# MNE Info describing the electrodes, so plot_topomap can place each channel on the
# standard 10-20 montage. A single ch_types string applies to every channel, so this no
# longer hard-codes a 32-channel layout.
# NOTE(review): sfreq is required by create_info but is not used by the topomap itself.
info = mne.create_info(channel_names, sfreq=500, ch_types="eeg")
info.set_montage("standard_1020")

fig = plt.figure()
# plt.axes() already registers the axes on `fig`, so no separate fig.add_axes(ax) is needed.
# The (0, 0, 1.5, 1.5) rect deliberately extends past the figure bounds (original layout).
ax = plt.axes((0, 0, 1.5, 1.5))

# One importance value per channel: sum over axis 1, then min-max scale to [0, 1].
# Assumes avg_w is (channels, timesteps), as labelled in the heatmap.
scaler = MinMaxScaler()
scaled_avg_w = scaler.fit_transform(avg_w.sum(1).reshape(-1, 1))

im, _ = mne.viz.plot_topomap(
    scaled_avg_w.reshape(-1),
    info,
    ch_type="eeg",
    sensors=True,
    names=channel_names,
    cmap="Blues",
    axes=ax,
    show=False,
    extrapolate="local",
)

cbar_ax = fig.add_axes([1.3, 0.2, 0.1, 1])
clb = fig.colorbar(im, cax=cbar_ax)
clb.set_label("Importance Estimate", rotation=270, labelpad=20)

# Enlarge only the electrode-name annotations and leave all other text untouched.
for tt in plt.findobj(fig, matplotlib.text.Text):
    if tt.get_text() in channel_names:
        tt.set_fontsize(14)