In [1]:
import numpy as np
import pickle
import sys
from typing import List
import torch
import os

In [2]:
# Data locations, all relative to the notebook's working directory.
EEG_ROOT = './EEG'
participant_data_path = f'{EEG_ROOT}/number'  # entries here are named <participant id>...; listed in the next cell
processed_data_path = f'{EEG_ROOT}/processed_data'  # NOTE(review): not referenced in the cells shown here
split_data_path = f'{EEG_ROOT}/split_data'  # presumably written by EEG-Split.ipynb (see the load cell's hint)

In [3]:
# Participant ids are the first PARTICIPANT_ID_LEN characters of each entry in
# participant_data_path; several entries may share one id, so ids are de-duplicated
# (set, O(n)) and returned as a sorted list so fold index i is stable across runs.
PARTICIPANT_ID_LEN = 3
participants = sorted({
    entry[:PARTICIPANT_ID_LEN]
    for entry in os.listdir(participant_data_path)
    # Skip hidden entries (e.g. .DS_Store, .ipynb_checkpoints) that would otherwise
    # turn into bogus participant ids like ".DS".
    if not entry.startswith(".")
})
# participants, len(participants)

In [4]:
# Load the per-participant ("user fold") split produced by the preprocessing notebooks.
load_name = "userfold_data_scaled_p_dictionary-number"
# Same directory as before ("./EEG/split_data/standard_scaled"), derived from the config cell.
data_dir = os.path.join(split_data_path, "standard_scaled")
pickle_path = os.path.join(data_dir, f"{load_name}.pkl")
try:
    # SECURITY: pickle.load can execute arbitrary code; only load files this project produced.
    with open(pickle_path, "rb") as fh:  # `with` guarantees the handle is closed
        raw_user_fold = pickle.load(fh)
except FileNotFoundError:
    # Only a genuinely missing file gets the "run the preprocessing notebooks" hint;
    # corrupt/unreadable pickles now surface their real error instead of this message.
    print("pickle file does not exist. Use EEG-Preprocess.ipynb and EEG-Split.ipynb to save data setting.")
    print(f"Expected file: {pickle_path}")
    sys.exit()

In [None]:
from utilities.userfold_framework import *
from utilities.EEG_func import *
import Models.model_func as Model_Func
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List
from torcheeg.models import EEGNet
# import Models.model_func as Model_Func
from torch import nn
from Models.multi_models import *

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- Training configuration ------------------------------------------------------
learning_rate = 0.00005
batch_size = 64
n_epochs = 300
transpose_channels = True
SEED = 0  # fixed seed so folds are reproducible (CUDA kernels may still be non-deterministic)

# Checkpoint directory is loop-invariant; create it up front because torch.save and
# open(..., "wb") both fail if the directory does not exist.
saved_dir = "./EEG/saved_models/Userfold/run0"
os.makedirs(saved_dir, exist_ok=True)

# --- Per-fold result accumulators (index-aligned with `participants`) ------------
participants_dictionary = []  # report dict returned by prediction_procedure, one per fold
participants_grads_dictionary = {}  # NOTE(review): never populated in this cell
b_acc_list = []
c0_acc_list = []
c1_acc_list = []

# One fold per participant; fold i is built by user_fold_load(i, ...) (presumably
# holding out participants[i] for validation; confirm in utilities.userfold_framework).
for i in range(len(participants)):
    # Re-seed per fold so each fold is reproducible independently of earlier folds.
    torch.manual_seed(SEED)
    np.random.seed(SEED)

    train_dataloader, val_dataloader, classes, input_dim, class_ratio = user_fold_load(
        i,
        raw_user_fold,
        participants,
        batch_size=batch_size,
        transpose_channels=transpose_channels,
    )

    # Alternatives previously tried here: ResNetPlus + LogSoftmax with NLLLoss,
    # DataGliderBasic_Model, and RMSprop instead of Adam.
    classifier = EEGNet(
        chunk_size=input_dim[1],
        num_electrodes=input_dim[0],
        num_classes=classes,
        kernel_1=32,
        kernel_2=32,
        F1=8,
        F2=16,
        dropout=0.5,
    ).to(DEVICE)

    # class_ratio is passed as per-class loss weights (presumably to counter class imbalance).
    criterion = torch.nn.CrossEntropyLoss(
        weight=torch.tensor(class_ratio, dtype=torch.float).to(DEVICE)
    )

    model = EEGNet_IE_Wrapper(DEVICE, classifier, input_dim).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # NOTE(review): a ReduceLROnPlateau(optimizer, patience=2) scheduler used to be built
    # here but was never passed to training_procedure; re-add it with scheduler=... to enable.

    model.training_procedure(
        iteration=n_epochs,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        print_cycle=2,
        path="./dictionary/intermdiate_dicts",
        loss_func=criterion,
        optimiser=optimizer,
        train_func=eeg_train,
    )
    # model.epoch == n_epochs + 1 appears to mean "ran all epochs": record n_epochs then,
    # otherwise the epoch at which training stopped.
    EPOCH = n_epochs if model.epoch == n_epochs + 1 else model.epoch

    file_stem = f"Userfold-{participants[i]}-EEGNet-Weight_Multivariate"
    torch.save(model.state_dict(), os.path.join(saved_dir, f"{file_stem}-e{EPOCH}.pt"))
    # `with` closes the pickle file deterministically (the old bare open() leaked it).
    with open(os.path.join(saved_dir, f"{file_stem}-w-e{EPOCH}.pkl"), "wb") as weights_file:
        pickle.dump(model.return_IE_weights(), weights_file)
    # To reuse a saved run instead of training, load the .pt above with
    # model.load_state_dict(torch.load(path)) and the .pkl with pickle.load.

    prediction, dictionary = model.prediction_procedure(val_dataloader, dict_flag=True)

    # NOTE(review): labels come from a second pass over val_dataloader; they only line up
    # with `prediction` if the validation loader does not shuffle. Confirm shuffle=False.
    ys = np.concatenate([y.detach().cpu().numpy() for _, y in val_dataloader])

    # Presumably class-0 accuracy, class-1 accuracy and balanced accuracy; see calculate_accuracy.
    c0_acc, c1_acc, b_acc = calculate_accuracy(ys, prediction)
    print("c0_acc", c0_acc, ", c1_acc", c1_acc, ", b_acc", b_acc)
    b_acc_list.append(b_acc)
    c0_acc_list.append(c0_acc)
    c1_acc_list.append(c1_acc)
    participants_dictionary.append(dictionary)

# --- Per-participant and averaged results -----------------------------------------
f1_scores = []
for participant, dictionary, c0_acc, c1_acc, b_acc in zip(
    participants, participants_dictionary, c0_acc_list, c1_acc_list, b_acc_list
):
    print(f"User {participant} f1: {dictionary['weighted avg']['f1-score']} acc: {dictionary['accuracy']}")
    print(f" c0: {c0_acc} c1: {c1_acc} bacc: {b_acc}")
    f1_scores.append(dictionary['weighted avg']['f1-score'])

print(f"average {np.mean(f1_scores)}")
print()
print(f"mean b_acc:  {np.mean(b_acc_list)}")
print(f"mean c1_acc: {np.mean(c1_acc_list)}")
print(f"mean c0_acc: {np.mean(c0_acc_list)}")


Iterations:   0%|                                           | 0/300 [00:00<?, ?it/s]

Epoch  0 , loss 0.7521746041728001


Iterations:   0%|                                   | 1/300 [00:01<07:15,  1.46s/it]

Epoch:  0
t_loss:  0.7521746041728001 , v_loss:  0.6901159485181173
t_acc:  0.4596949891067538 , v_acc:  0.5838509316770186
t_recall:  0.5311236463951908 , v_recall:  0.5827927927927927
t_prec:  0.5298239522753636 , v_prec:  0.5712403100775194
t_f:  0.4593513807759144 , v_f:  0.5619492385786802
////////


Iterations:   1%|▏                                  | 2/300 [00:02<05:28,  1.10s/it]

Epoch  1 , loss 0.7207762061380872
Epoch  2 , loss 0.6909797717543209


Iterations:   1%|▎                                  | 3/300 [00:03<06:05,  1.23s/it]

Epoch:  2
t_loss:  0.6909797717543209 , v_loss:  0.6827052434285482
t_acc:  0.548708372237784 , v_acc:  0.639751552795031
t_recall:  0.5516881459216865 , v_recall:  0.6343243243243244
t_prec:  0.543804828004941 , v_prec:  0.6170329670329671
t_f:  0.5282563209005278 , v_f:  0.6147689768976898
////////


Iterations:   1%|▍                                  | 4/300 [00:04<05:22,  1.09s/it]

Epoch  3 , loss 0.6670086675999212
Epoch  4 , loss 0.6430516663719626


Iterations:   2%|▌                                  | 5/300 [00:05<05:51,  1.19s/it]

Epoch:  4
t_loss:  0.6430516663719626 , v_loss:  0.6662986278533936
t_acc:  0.6050420168067226 , v_acc:  0.7142857142857143
t_recall:  0.5378364282675231 , v_recall:  0.6389189189189189
t_prec:  0.5372298714078513 , v_prec:  0.6592975206611571
t_f:  0.5374794180855402 , v_f:  0.6453065134099618
////////


Iterations:   2%|▋                                  | 6/300 [00:06<05:10,  1.05s/it]

Epoch  5 , loss 0.6249698049881879
Epoch  6 , loss 0.6035522129021439


Iterations:   2%|▊                                  | 7/300 [00:07<05:26,  1.12s/it]

Epoch:  6
t_loss:  0.6035522129021439 , v_loss:  0.6464685996373495
t_acc:  0.6402116402116402 , v_acc:  0.7391304347826086
t_recall:  0.5182607979577551 , v_recall:  0.6074774774774775
t_prec:  0.5260670749626774 , v_prec:  0.7317404817404818
t_f:  0.5107037845015651 , v_f:  0.6084993052339046
////////


Iterations:   3%|▉                                  | 8/300 [00:08<04:51,  1.00it/s]

Epoch  7 , loss 0.5843391272367215
Epoch  8 , loss 0.5653581981565438


Iterations:   3%|█                                  | 9/300 [00:10<05:19,  1.10s/it]

Epoch:  8
t_loss:  0.5653581981565438 , v_loss:  0.6315245032310486
t_acc:  0.6728913787737317 , v_acc:  0.7267080745341615
t_recall:  0.5233485678209193 , v_recall:  0.5709909909909909
t_prec:  0.5510596916969474 , v_prec:  0.7609271523178809
t_f:  0.5009345479476227 , v_f:  0.5493638676844783
////////


Iterations:   3%|█▏                                | 10/300 [00:10<04:50,  1.00s/it]

Epoch  9 , loss 0.5539117303549075
Epoch  10 , loss 0.5410025634017646


Iterations:   4%|█▏                                | 11/300 [00:12<05:18,  1.10s/it]

Epoch:  10
t_loss:  0.5410025634017646 , v_loss:  0.6222745478153229
t_acc:  0.6834733893557423 , v_acc:  0.7018633540372671
t_recall:  0.5145654053608927 , v_recall:  0.5254954954954955
t_prec:  0.5525881548433134 , v_prec:  0.7253184713375795
t_f:  0.47178976863551253 , v_f:  0.4660033167495854
////////


Iterations:   4%|█▎                                | 12/300 [00:12<04:49,  1.01s/it]

Epoch  11 , loss 0.527734659466089
Epoch  12 , loss 0.5205557854736552


Iterations:   4%|█▍                                | 13/300 [00:14<05:11,  1.08s/it]

Epoch:  12
t_loss:  0.5205557854736552 , v_loss:  0.6175274848937988
t_acc:  0.6881419234360411 , v_acc:  0.7018633540372671
t_recall:  0.5101576975336599 , v_recall:  0.5254954954954955
t_prec:  0.554594140037178 , v_prec:  0.7253184713375795
t_f:  0.4544048563235923 , v_f:  0.4660033167495854
////////


Iterations:   5%|█▌                                | 14/300 [00:14<04:42,  1.01it/s]

Epoch  13 , loss 0.5090562835627911
Epoch  14 , loss 0.4993336615609188


Iterations:   5%|█▋                                | 15/300 [00:16<05:14,  1.11s/it]

Epoch:  14
t_loss:  0.4993336615609188 , v_loss:  0.6148425539334615
t_acc:  0.6887643946467475 , v_acc:  0.7080745341614907
t_recall:  0.503991847490427 , v_recall:  0.53
t_prec:  0.5331975071907957 , v_prec:  0.8512658227848101
t_f:  0.4359761877677929 , v_f:  0.4692431787893666
////////


Iterations:   5%|█▊                                | 16/300 [00:17<04:49,  1.02s/it]

Epoch  15 , loss 0.48873810616194036
Epoch  16 , loss 0.47960536328016545


Iterations:   6%|█▉                                | 17/300 [00:18<05:10,  1.10s/it]

Epoch:  16
t_loss:  0.47960536328016545 , v_loss:  0.6136613090833029
t_acc:  0.6977902272019919 , v_acc:  0.7018633540372671
t_recall:  0.5078917390647946 , v_recall:  0.52
t_prec:  0.6467935189597658 , v_prec:  0.8490566037735849
t_f:  0.43195024184451847 , v_f:  0.44957264957264953
////////


Iterations:   6%|██                                | 18/300 [00:19<04:40,  1.00it/s]

Epoch  17 , loss 0.4810873573901607
Epoch  18 , loss 0.4713788406521666


Iterations:   6%|██▏                               | 19/300 [00:20<05:08,  1.10s/it]

Epoch:  18
t_loss:  0.4713788406521666 , v_loss:  0.6142550408840179
t_acc:  0.6962340491752257 , v_acc:  0.7018633540372671
t_recall:  0.5059105694404413 , v_recall:  0.52
t_prec:  0.6129603399433428 , v_prec:  0.8490566037735849
t_f:  0.4285676799748094 , v_f:  0.44957264957264953
////////


Iterations:   7%|██▎                               | 20/300 [00:21<04:42,  1.01s/it]

Epoch  19 , loss 0.47658727975452647
Epoch  20 , loss 0.4630153091514812


Iterations:   7%|██▍                               | 21/300 [00:22<05:10,  1.11s/it]

Epoch:  20
t_loss:  0.4630153091514812 , v_loss:  0.6152823567390442
t_acc:  0.6940553999377529 , v_acc:  0.7018633540372671
t_recall:  0.5014692359424109 , v_recall:  0.52
t_prec:  0.5437712961700968 , v_prec:  0.8490566037735849
t_f:  0.41839193064999514 , v_f:  0.44957264957264953
////////


Iterations:   7%|██▍                               | 22/300 [00:23<04:41,  1.01s/it]

Epoch  21 , loss 0.45874982896973104
Epoch  22 , loss 0.4577232675225127


Iterations:   8%|██▌                               | 23/300 [00:24<05:03,  1.10s/it]

Epoch:  22
t_loss:  0.4577232675225127 , v_loss:  0.6162187457084656
t_acc:  0.6949891067538126 , v_acc:  0.6894409937888198
t_recall:  0.5015653092875475 , v_recall:  0.5
t_prec:  0.5668888802001877 , v_prec:  0.3447204968944099
t_f:  0.41683606795411643 , v_f:  0.4080882352941176
////////


Iterations:   8%|██▋                               | 24/300 [00:25<04:40,  1.01s/it]

Epoch  23 , loss 0.46153216502245736
Epoch  24 , loss 0.4491976482026717


Iterations:   8%|██▊                               | 25/300 [00:26<04:56,  1.08s/it]

Epoch:  24
t_loss:  0.4491976482026717 , v_loss:  0.6171568433443705
t_acc:  0.6971677559912854 , v_acc:  0.6894409937888198
t_recall:  0.5031313048132746 , v_recall:  0.5
t_prec:  0.7373595505617978 , v_prec:  0.3447204968944099
t_f:  0.41764560913497084 , v_f:  0.4080882352941176
////////


Iterations:   9%|██▉                               | 26/300 [00:27<04:31,  1.01it/s]

Epoch  25 , loss 0.44541018616919426
Epoch  26 , loss 0.45534155999912934


Iterations:   9%|███                               | 27/300 [00:28<04:54,  1.08s/it]

Epoch:  26
t_loss:  0.45534155999912934 , v_loss:  0.6182516018549601
t_acc:  0.6962340491752257 , v_acc:  0.6894409937888198
t_recall:  0.5024601638736773 , v_recall:  0.5
t_prec:  0.6399953139643861 , v_prec:  0.3447204968944099
t_f:  0.4172988159743127 , v_f:  0.4080882352941176
////////


Iterations:   9%|███▏                              | 28/300 [00:29<04:36,  1.02s/it]

Epoch  27 , loss 0.44229792496737314
Epoch  28 , loss 0.452491913356033


Iterations:  10%|███▎                              | 29/300 [00:31<04:57,  1.10s/it]

Epoch:  28
t_loss:  0.452491913356033 , v_loss:  0.6182934542497
t_acc:  0.6959228135698724 , v_acc:  0.6894409937888198
t_recall:  0.5013738488354538 , v_recall:  0.5
t_prec:  0.6338116032439176 , v_prec:  0.3447204968944099
t_f:  0.4142796235015772 , v_f:  0.4080882352941176
////////


Iterations:  10%|███▍                              | 30/300 [00:31<04:39,  1.04s/it]

Epoch  29 , loss 0.446802207067901
Epoch  30 , loss 0.44052381725872264


In [None]:
# Summarise all folds. participants_dictionary holds one report per fold, appended in
# the same order as `participants` by the training loop. Both helpers come from the
# star-imported utilities; their output format is defined there, not in this notebook.
userfold_results_summary(participants_dictionary, participants)
userfold_classwise_results_summary(participants_dictionary, participants)


In [None]:
# Display return_IE_weights() for `model`. The training loop rebinds `model` on every fold,
# so this is the model trained on the LAST fold only; per-fold weights are in the .pkl files.
model.return_IE_weights()