## Import all library

In [1]:
import numpy as np
import pandas as pd
import seaborn as sns
import mne
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import pywt 
from PIL import Image
from utility import *
from model import *
from torchsummary import summary
import torch.optim.lr_scheduler as lr_scheduler

# total_sample = 3628
# num_spike    = 903
# ((total_sample-num_spike)*2)/(num_spike*2) = 3.032
LOSS_POS_WEIGHT       = torch.tensor([3.032])


MODEL_FILE_DIRC_SateLight = MODEL_FILE_DIRC + "/SateLight"
MODEL_FILE_DIRC_CNN       = MODEL_FILE_DIRC + "/CNN"
os.makedirs(MODEL_FILE_DIRC_CNN, exist_ok=True)
os.makedirs(MODEL_FILE_DIRC_SateLight, exist_ok=True)

torch.manual_seed(3407)

<torch._C.Generator at 0x24744b59d90>

In [1]:
total_sample = 3628
num_spike    = 903
((total_sample-num_spike)*2)/(num_spike*2) 

3.017718715393134

## Data Preparation
* Convert the data from csv to Dataloader

In [2]:
%%time
eeg_num_list = list(range(1,21))

dataloaders, num_data = get_dataloader(eeg_num_list, shuffle=True, get_dataDWT=False)
    

num_train_data, num_valid_data, num_test_data = num_data
train_data, valid_data, test_data             = dataloaders

The data from EEG_csv/eeg1.csv is loaded 
There is spike in this eeg file
Data before split : (3840, 19)
Data with   spike: (3840, 19)
Data after  split into window: (3, 1280, 19)
Labels: (3,)
Num spike: 3
EEG1 has 3 windows of data 


The data from EEG_csv/eeg2.csv is loaded 
There is no spike in this eeg file
(160, 1280, 19)
EEG2 has 160 windows of data 


The data from EEG_csv/eeg3.csv is loaded 
There is no spike in this eeg file
(153, 1280, 19)
EEG3 has 153 windows of data 


The data from EEG_csv/eeg4.csv is loaded 
There is no spike in this eeg file
(170, 1280, 19)
EEG4 has 170 windows of data 


The data from EEG_csv/eeg5.csv is loaded 
There is no spike in this eeg file
(148, 1280, 19)
EEG5 has 148 windows of data 


The data from EEG_csv/eeg6.csv is loaded 
There is no spike in this eeg file
(207, 1280, 19)
EEG6 has 207 windows of data 


The data from EEG_csv/eeg7.csv is loaded 
There is no spike in this eeg file
(147, 1280, 19)
EEG7 has 147 windows of data 


The data from 

## Function to train the model

In [3]:
def start_classification_model_training(EPOCH_START, NUM_EPOCHS_CLASSIFIER, 
                                        model, MODEL_FILE_DIRC, model_name,  
                                        df, prev_best_valid_f1,prev_best_valid_loss,
                                        train_data, num_train_data, 
                                        valid_data, num_valid_data, 
                                        scheduler, optimizer, device, LOSS_POS_WEIGHT):
    count = 0
    for epoch in range(EPOCH_START, NUM_EPOCHS_CLASSIFIER):
            ## 1. Training
            model.train()
            train_loss, train_metric = train_classifier(model, train_data, device, num_train_data, optimizer, LOSS_POS_WEIGHT=LOSS_POS_WEIGHT) 
            
            ## 2. Evaluating
            model.eval()
            valid_loss, valid_metric = evaluate_classifier(model, valid_data, device, num_valid_data, LOSS_POS_WEIGHT=LOSS_POS_WEIGHT) 
            
            ## 3. Show the result
            list_data       = [train_loss, valid_loss]
            for key in ["precision", "accuracy", "f1_score", "recall"]:
                list_data.append(train_metric[key])
                list_data.append(valid_metric[key])
            df.loc[len(df)] = list_data
            
            print_log(f"> > > Epoch     : {epoch}", MODEL_FILE_DIRC)
            print_log(f"Train {'loss':<10}: {train_loss}", MODEL_FILE_DIRC)
            print_log(f"Valid {'loss':<10}: {valid_loss}", MODEL_FILE_DIRC)
            for key in ["precision", "accuracy", "f1_score", "recall"]:
                print_log(f"Train {key:<10}: {train_metric[key]}", MODEL_FILE_DIRC)
                print_log(f"Valid {key:<10}: {valid_metric[key]}", MODEL_FILE_DIRC)

            
            ## 3.1 Plot the loss function
            fig,ax = plt.subplots(3,2, figsize=(10,10))
            x_data = range(len(df["Train Loss"]))
            for i, key in enumerate(["Loss","precision", "accuracy", "f1_score", "recall"]):    
                ax[i%3][i//3].plot(x_data, df[f"Train {key}"], label=f"Train {key}")
                ax[i%3][i//3].plot(x_data, df[f"Valid {key}"], label=f"Valid {key}")
                ax[i%3][i//3].legend()
            plt.savefig(f'{MODEL_FILE_DIRC}/Loss.png', transparent=False, facecolor='white')
            plt.close('all')

            ## 3.3. Save model and Stoping criteria
            if prev_best_valid_f1 <= valid_metric["f1_score"]:  # If previous best validation f1-score <= current f1-score 
                state_dict = {
                    "model": model.state_dict(), 
                    "epoch":epoch,
                    "valid_f1_score": valid_metric["f1_score"],
                    "valid_loss": valid_loss
                }
                torch.save(state_dict, f"{MODEL_FILE_DIRC}/{model_name}_best.pt")
                prev_best_valid_f1 = valid_metric["f1_score"]  # Previous validation loss = Current validation loss
                count = 0
            else:
                count += 1
            
            if epoch % 5 == 0:
                state_dict = {
                    "model": model.state_dict(), 
                    "epoch":epoch,
                    "valid_f1_score": valid_metric["f1_score"],
                    "valid_loss": valid_loss
                }
                torch.save(state_dict, f"{MODEL_FILE_DIRC}/{model_name}_{epoch}.pt")
            
            df.to_csv(f"{MODEL_FILE_DIRC}/Loss.csv", index=False)
            
            if count == MAX_COUNT_F1_SCORE:
                print_log(f"The validation f1 score is not increasing for continuous {MAX_COUNT_F1_SCORE} time, so training stop", MODEL_FILE_DIRC)
                break
            
            scheduler.step()

In [4]:
def load_classification_model_dict(model, MODEL_FILE_DIRC, model_name):
    list_model = os.listdir(MODEL_FILE_DIRC) 
    if len(list_model) > 0:    # Load the latest trained model
        if os.path.exists(f"{MODEL_FILE_DIRC}/{model_name}_best.pt"):
            state_dict_loaded    = torch.load(f"{MODEL_FILE_DIRC}/{model_name}_best.pt")
            prev_best_valid_f1   = state_dict_loaded["valid_f1_score"]
            prev_best_valid_loss = state_dict_loaded["valid_loss"]
        list_model.remove(f"{model_name}_best.pt")
        num_list   = [int(model_dir[model_dir.rindex("_") +1: model_dir.rindex(".")]) for model_dir in list_model if model_dir.endswith(".pt")]
        num_max    = np.max(num_list)
        
        state_dict_loaded = torch.load(f"{MODEL_FILE_DIRC}/{model_name}_{num_max}.pt")
        model.load_state_dict(state_dict_loaded["model"])
        EPOCH_START = state_dict_loaded["epoch"] + 1
        
        print(f"The model has been loaded from the file '{model_name}_{num_max}.pt'")

        if os.path.exists(f"{MODEL_FILE_DIRC}/Loss.csv"):
            df = pd.read_csv(f"{MODEL_FILE_DIRC}/Loss.csv")
            df = df.iloc[:EPOCH_START-1, :]
            print(f"The dataframe that record the loss have been loaded from {MODEL_FILE_DIRC}/Loss.csv")

    else:
        EPOCH_START            = 1
        prev_best_valid_f1     = -1
        prev_best_valid_loss   = 10000
        df                     = pd.DataFrame(columns = ["Train Loss", "Valid Loss"] + \
                                                         flatten_concatenation([[f"Train {metric}", f"Valid {metric}"] for metric in ["precision", "accuracy", "f1_score", "recall"]]) )
    return model, df, EPOCH_START, prev_best_valid_f1, prev_best_valid_loss

## Build and Train the SateLight model

### Build the SateLight model

In [5]:
model      = SateLight().to(device)
optimizer  = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)  
scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

# Get the following information:
# 1. Previous Trained model (if exist)
# 2. df that store the training/validation loss & metrics
# 3. epoch where the training start
# 4. Previous Highest Validation F1-score
# 5. Previous Lowest  Validation Loss
model, df, EPOCH_START, prev_best_valid_f1, prev_best_valid_loss = load_classification_model_dict(model, MODEL_FILE_DIRC_SateLight, "SateLight")

# Get the summary of the model
print(summary(model, (19,1280)))

Layer (type:depth-idx)                        Output Shape              Param #
├─Sequential: 1-1                             [-1, 32, 1, 641]          --
|    └─Conv2d: 2-1                            [-1, 16, 19, 641]         10,256
|    └─Conv2d: 2-2                            [-1, 32, 1, 641]          640
├─BatchNorm1d: 1-2                            [-1, 32, 641]             64
├─ReLU: 1-3                                   [-1, 32, 641]             --
├─Dropout: 1-4                                [-1, 32, 641]             --
├─MaxPool1d: 1-5                              [-1, 32, 160]             --
├─ModuleList: 1                               []                        --
|    └─Sequential: 2-3                        [-1, 160, 32]             --
|    |    └─SelfAttention: 3-1                [-1, 160, 32]             7,392
|    |    └─BatchNorm1d: 3-2                  [-1, 160, 32]             320
|    |    └─Dropout: 3-3                      [-1, 160, 32]             --
├─ModuleLis

### Print out the model info

In [6]:
seperate = "\n" + "-" * 100 + "\n"
print(seperate + "Model infomation" + seperate)
print(f"Device used        :", device)
print(f"BATCH SIZE         :", BATCH_SIZE)
print(f"MAX_COUNT_F1_SCORE :", MAX_COUNT_F1_SCORE)
print(f"LEARNING RATE      :", LEARNING_RATE)
print(f"Prev Best f1-score in validation dataset:", prev_best_valid_f1)
print(f"Prev Best validation loss               :", prev_best_valid_loss)
print(f"Number of EPOCH for training   :",NUM_EPOCHS_CLASSIFIER, f"(EPOCH start from {EPOCH_START})")
print(f"Num of epochs of data for train:", num_train_data)
print(f"Num of epochs of data for valid:", num_valid_data)
print(f'Model parameters               : {sum(p.numel() for p in model.parameters()):,}' )


----------------------------------------------------------------------------------------------------
Model infomation
----------------------------------------------------------------------------------------------------

Device used        : cuda
BATCH SIZE         : 32
MAX_COUNT_F1_SCORE : 10
LEARNING RATE      : 0.001
Prev Best f1-score in validation dataset: -1
Prev Best validation loss               : 10000
Number of EPOCH for training   : 101 (EPOCH start from 1)
Num of epochs of data for train: 2542
Num of epochs of data for valid: 543
Model parameters               : 134,521


### Start Training Loop

In [7]:
%%time
start_classification_model_training(EPOCH_START, NUM_EPOCHS_CLASSIFIER, 
                                    model, MODEL_FILE_DIRC_SateLight, "SateLight",
                                    df, prev_best_valid_f1, prev_best_valid_loss, 
                                    train_data, num_train_data, 
                                    valid_data, num_valid_data, 
                                    scheduler, optimizer, device, LOSS_POS_WEIGHT)

> > > Epoch     : 1
Train loss      : 0.016024461286328402
Valid loss      : 0.012869781325699875
Train precision : 0.7213541666666666
Valid precision : 0.8037974683544303
Train accuracy  : 0.8839496459480723
Valid accuracy  : 0.9300184162062615
Train f1_score  : 0.789736279401283
Valid f1_score  : 0.8698630136986302
Train recall    : 0.8724409448818897
Valid recall    : 0.9477611940298507
> > > Epoch     : 2
Train loss      : 0.003605944743713949
Valid loss      : 0.0036866476130714677
Train precision : 0.9523809523809523
Valid precision : 0.9705882352941176
Train accuracy  : 0.981904012588513
Valid accuracy  : 0.988950276243094
Train f1_score  : 0.9642301710730948
Valid f1_score  : 0.9777777777777777
Train recall    : 0.9763779527559056
Valid recall    : 0.9850746268656716
> > > Epoch     : 3
Train loss      : 0.0009724713005308481
Valid loss      : 0.00729109709383491
Train precision : 0.9859375
Valid precision : 0.9767441860465116
Train accuracy  : 0.9948859166011015
Valid accuracy

### Test the model performance on testing dataset

In [8]:
# Load the best model and turn to evaluation mode
model.eval()
state_dict_loaded = torch.load(f"{MODEL_FILE_DIRC_SateLight}/SateLight_best.pt")
model.load_state_dict(state_dict_loaded["model"])

test_loss, test_metric = evaluate_classifier(model, test_data, device, num_test_data) 

print("Best model is at epoch:",state_dict_loaded["epoch"])
print("Metric on testing dataset:")
for key, value in test_metric.items():
    print(f"{key:<10}: {value:.4f}")

Best model is at epoch: 24
Metric on testing dataset:
precision : 0.7043
accuracy  : 0.8932
f1_score  : 0.8187
recall    : 0.9776


## Build and Train the Simple CNN model

### Build the CNN model

In [5]:
model      = CNN().to(device)
optimizer  = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)  
scheduler  = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

# Get the following information:
# 1. Previous Trained model (if exist)
# 2. df that store the training/validation loss & metrics
# 3. epoch where the training start
# 4. Previous Highest Validation F1-score
# 5. Previous Lowest  Validation Loss
model, df, EPOCH_START, prev_best_valid_f1, prev_best_valid_loss = load_classification_model_dict(model, MODEL_FILE_DIRC_CNN, "CNN")

# Get the summary of the model
print(summary(model, (19,1280)))

The model has been loaded from the file 'CNN_20.pt'
The dataframe that record the loss have been loaded from Model/CNN/Loss.csv
Layer (type:depth-idx)                        Output Shape              Param #
├─Sequential: 1-1                             [-1, 128, 8]              --
|    └─Conv1dWithInitialization: 2-1          [-1, 128, 1278]           --
|    |    └─Conv1d: 3-1                       [-1, 128, 1278]           7,424
|    └─BatchNorm1d: 2-2                       [-1, 128, 1278]           256
|    └─ReLU: 2-3                              [-1, 128, 1278]           --
|    └─MaxPool1d: 2-4                         [-1, 128, 639]            --
|    └─Conv1dWithInitialization: 2-5          [-1, 64, 637]             --
|    |    └─Conv1d: 3-2                       [-1, 64, 637]             24,640
|    └─BatchNorm1d: 2-6                       [-1, 64, 637]             128
|    └─ReLU: 2-7                              [-1, 64, 637]             --
|    └─MaxPool1d: 2-8            

### Print out the model info

In [10]:
seperate = "\n" + "-" * 100 + "\n"
print(seperate + "Model infomation" + seperate)
print(f"Device used        :", device)
print(f"BATCH SIZE         :", BATCH_SIZE)
print(f"MAX_COUNT_F1_SCORE :", MAX_COUNT_F1_SCORE)
print(f"LEARNING RATE      :", LEARNING_RATE)
print(f"Prev Best recall in validation dataset:", prev_best_valid_f1)
print(f"Prev Best validation loss             :", prev_best_valid_loss)
print(f"Number of EPOCH for training   :",NUM_EPOCHS_CLASSIFIER, f"(EPOCH start from {EPOCH_START})")
print(f"Num of epochs of data for train:", num_train_data)
print(f"Num of epochs of data for valid:", num_valid_data)
print(f'Model parameters               : {sum(p.numel() for p in model.parameters()):,}' )


----------------------------------------------------------------------------------------------------
Model infomation
----------------------------------------------------------------------------------------------------

Device used        : cuda
BATCH SIZE         : 32
MAX_COUNT_F1_SCORE : 10
LEARNING RATE      : 0.001
Prev Best recall in validation dataset: -1
Prev Best validation loss             : 10000
Number of EPOCH for training   : 101 (EPOCH start from 1)
Num of epochs of data for train: 2542
Num of epochs of data for valid: 543
Model parameters               : 74,225


### Start Training Loop

In [11]:
%%time
start_classification_model_training(EPOCH_START, NUM_EPOCHS_CLASSIFIER, 
                                    model, MODEL_FILE_DIRC_CNN, "CNN",
                                    df, prev_best_valid_f1, prev_best_valid_loss, 
                                    train_data, num_train_data, 
                                    valid_data, num_valid_data, 
                                    scheduler, optimizer, device, LOSS_POS_WEIGHT)

> > > Epoch     : 1
Train loss      : 0.007718846433994563
Valid loss      : 0.0042417199573004495
Train precision : 0.8393113342898135
Valid precision : 0.900709219858156
Train accuracy  : 0.9362706530291109
Valid accuracy  : 0.9613259668508287
Train f1_score  : 0.8783783783783784
Valid f1_score  : 0.9236363636363636
Train recall    : 0.9212598425196851
Valid recall    : 0.9477611940298507
> > > Epoch     : 2
Train loss      : 0.0007576334199437287
Valid loss      : 0.0013085228423579402
Train precision : 0.9829192546583851
Valid precision : 0.9568345323741008
Train accuracy  : 0.9948859166011015
Valid accuracy  : 0.9871086556169429
Train f1_score  : 0.9898358092259578
Valid f1_score  : 0.9743589743589743
Train recall    : 0.9968503937007874
Valid recall    : 0.9925373134328358
> > > Epoch     : 3
Train loss      : 0.0006958726925133132
Valid loss      : 0.0047371883114160325
Train precision : 0.9936808846761453
Valid precision : 1.0
Train accuracy  : 0.996066089693155
Valid accuracy 

### Test the model performance on testing dataset

In [12]:
# Load the best model and turn to evaluation mode
model.eval()
state_dict_loaded = torch.load(f"{MODEL_FILE_DIRC_CNN}/CNN_best.pt")
model.load_state_dict(state_dict_loaded["model"])

test_loss, test_metric = evaluate_classifier(model, test_data, device, num_test_data) 

print("Best model is at epoch:",state_dict_loaded["epoch"])
print("Metric on testing dataset:")
for key, value in test_metric.items():
    print(f"{key:<10}: {value:.4f}")

Best model is at epoch: 13
Metric on testing dataset:
precision : 0.9416
accuracy  : 0.9761
f1_score  : 0.9520
recall    : 0.9627
