## Import all library

In [1]:
import numpy as np
import pandas as pd
import seaborn as sns
import mne
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import pywt 
from PIL import Image
from utility import *
from model import *
from torchsummary import summary
import torch.optim.lr_scheduler as lr_scheduler


# total_sample = 6204
# num_spike    = 1535
# ((total_sample-num_spike)*2)/(num_spike*2) = 3.041
LOSS_POS_WEIGHT       = torch.tensor([3.041])

MODEL_FILE_DIRC_SateLight = MODEL_FILE_DIRC + "/SateLight_synthesized"
MODEL_FILE_DIRC_CNN       = MODEL_FILE_DIRC + "/CNN_synthesized"
os.makedirs(MODEL_FILE_DIRC_CNN, exist_ok=True)
os.makedirs(MODEL_FILE_DIRC_SateLight, exist_ok=True)

torch.manual_seed(3407)

<torch._C.Generator at 0x22fbe6aad50>

## Data Preparation
* Convert the data from csv to Dataloader

In [2]:
%%time
eeg_num_list = list(range(1,21))
datasets, datasets_label, datasets_DWT = get_dataloader(eeg_num_list, get_dataloader=False, shuffle=False)
    
train_dataset, valid_dataset, test_dataset = datasets
train_label,   valid_label,   test_label   = datasets_label
train_DWT,     valid_DWT,     test_DWT     = datasets_DWT

# Load the synthesized data and process it (copy from Evaluate Diffusion Model)
df                  = pd.read_csv("synthesized_data.csv")  # Shape=(num_data, num_channel)
synthesized_data = torch.tensor(df.values)
synthesized_data = synthesized_data.view(synthesized_data.shape[0]//(DURATION*NEW_FREQUENCY),-1 , len(AVE_CHANNELS_NAME)) # Shape=(bs, num_signal, num_channel)
synthesized_data = synthesized_data.permute(0,2,1)
synthesized_data = synthesized_data.type(torch.float32) 


# Get the label and data after discrete wavelet transform
synthesized_data_label = train_label[:len(synthesized_data)].clone()
synthesized_data_DWT   = pywt.wavedec(synthesized_data, 'db1')                         # Apply DWT
synthesized_data_DWT   = np.concatenate(synthesized_data_DWT, axis=2, dtype=np.float32)# Concatenate all DWT data
synthesized_data_DWT   = torch.from_numpy(synthesized_data_DWT)

# Concatenate synthesized data and training data
train_dataset = torch.cat([train_dataset, synthesized_data], axis=0) 
train_label   = torch.cat([train_label, synthesized_data_label], axis=0)
train_DWT     = torch.cat([train_DWT, synthesized_data_DWT], axis=0)

print("Number of training data:", train_dataset.shape[0])
print("Number of spike in training data:", (train_label==1).sum())


The data from EEG_csv/eeg1.csv is loaded 
There is spike in this eeg file
Data before split : (3840, 19)
Data with   spike: (3840, 19)
Data after  split into window: (3, 1280, 19)
Labels: (3,)
Num spike: 3
EEG1 has 3 windows of data 


The data from EEG_csv/eeg2.csv is loaded 
There is no spike in this eeg file
(160, 1280, 19)
EEG2 has 160 windows of data 


The data from EEG_csv/eeg3.csv is loaded 
There is no spike in this eeg file
(153, 1280, 19)
EEG3 has 153 windows of data 


The data from EEG_csv/eeg4.csv is loaded 
There is no spike in this eeg file
(170, 1280, 19)
EEG4 has 170 windows of data 


The data from EEG_csv/eeg5.csv is loaded 
There is no spike in this eeg file
(148, 1280, 19)
EEG5 has 148 windows of data 


The data from EEG_csv/eeg6.csv is loaded 
There is no spike in this eeg file
(207, 1280, 19)
EEG6 has 207 windows of data 


The data from EEG_csv/eeg7.csv is loaded 
There is no spike in this eeg file
(147, 1280, 19)
EEG7 has 147 windows of data 


The data from 

In [3]:
# Get the number of data and Place it into dataloader
num_train_data = train_dataset.shape[0]
num_valid_data = valid_dataset.shape[0]
num_test_data  = test_dataset.shape[0]
train_data = DataLoader(dataset = Dataset_Class1(train_dataset, train_label), 
                        batch_size = BATCH_SIZE, shuffle = True, num_workers=1)
valid_data = DataLoader(dataset = Dataset_Class1(valid_dataset, valid_label), 
                        batch_size = BATCH_SIZE, shuffle = True, num_workers=1)
test_data  = DataLoader(dataset = Dataset_Class1(test_dataset, test_label), 
                        batch_size = BATCH_SIZE, shuffle = True, num_workers=1)

## Function to train the model

In [4]:
def load_classification_model_dict(model, MODEL_FILE_DIRC, model_name):
    list_model = os.listdir(MODEL_FILE_DIRC) 
    if len(list_model) > 0:    # Load the latest trained model
        if os.path.exists(f"{MODEL_FILE_DIRC}/{model_name}_best.pt"):
            state_dict_loaded    = torch.load(f"{MODEL_FILE_DIRC}/{model_name}_best.pt")
            prev_best_valid_f1   = state_dict_loaded["valid_f1_score"]
            prev_best_valid_loss = state_dict_loaded["valid_loss"]
        list_model.remove(f"{model_name}_best.pt")
        num_list   = [int(model_dir[model_dir.rindex("_") +1: model_dir.rindex(".")]) for model_dir in list_model if model_dir.endswith(".pt")]
        num_max    = np.max(num_list)
        
        state_dict_loaded = torch.load(f"{MODEL_FILE_DIRC}/{model_name}_{num_max}.pt")
        model.load_state_dict(state_dict_loaded["model"])
        EPOCH_START = state_dict_loaded["epoch"] + 1
        
        print(f"The model has been loaded from the file '{model_name}_{num_max}.pt'")

        if os.path.exists(f"{MODEL_FILE_DIRC}/Loss.csv"):
            df = pd.read_csv(f"{MODEL_FILE_DIRC}/Loss.csv")
            df = df.iloc[:EPOCH_START-1, :]
            print(f"The dataframe that record the loss have been loaded from {MODEL_FILE_DIRC}/Loss.csv")

    else:
        EPOCH_START            = 1
        prev_best_valid_f1     = -1
        prev_best_valid_loss   = 10000
        df                     = pd.DataFrame(columns = ["Train Loss", "Valid Loss"] + \
                                                         flatten_concatenation([[f"Train {metric}", f"Valid {metric}"] for metric in ["precision", "accuracy", "f1_score", "recall"]]) )
    return model, df, EPOCH_START, prev_best_valid_f1, prev_best_valid_loss

def start_classification_model_training(EPOCH_START, NUM_EPOCHS_CLASSIFIER, 
                                        model, MODEL_FILE_DIRC, model_name,  
                                        df, prev_best_valid_f1,prev_best_valid_loss,
                                        train_data, num_train_data, 
                                        valid_data, num_valid_data, 
                                        scheduler, optimizer, device):
    count = 0
    for epoch in range(EPOCH_START, NUM_EPOCHS_CLASSIFIER):
        
            ## 1. Training
            model.train()
            train_loss, train_metric = train_classifier(model, train_data, device, num_train_data, optimizer, LOSS_POS_WEIGHT=torch.tensor([3])) 
            
            ## 2. Evaluating
            model.eval()
            valid_loss, valid_metric = evaluate_classifier(model, valid_data, device, num_valid_data, LOSS_POS_WEIGHT=torch.tensor([3]))  
            
            ## 3. Show the result
            list_data       = [train_loss, valid_loss]
            for key in ["precision", "accuracy", "f1_score", "recall"]:
                list_data.append(train_metric[key])
                list_data.append(valid_metric[key])
            df.loc[len(df)] = list_data
            
            print_log(f"> > > Epoch     : {epoch}", MODEL_FILE_DIRC)
            print_log(f"Train {'loss':<10}: {train_loss}", MODEL_FILE_DIRC)
            print_log(f"Valid {'loss':<10}: {valid_loss}", MODEL_FILE_DIRC)
            for key in ["precision", "accuracy", "f1_score", "recall"]:
                print_log(f"Train {key:<10}: {train_metric[key]}", MODEL_FILE_DIRC)
                print_log(f"Valid {key:<10}: {valid_metric[key]}", MODEL_FILE_DIRC)

            
            ## 3.1 Plot the loss function
            fig,ax = plt.subplots(3,2, figsize=(10,10))
            x_data = range(len(df["Train Loss"]))
            for i, key in enumerate(["Loss","precision", "accuracy", "f1_score", "recall"]):    
                ax[i%3][i//3].plot(x_data, df[f"Train {key}"], label=f"Train {key}")
                ax[i%3][i//3].plot(x_data, df[f"Valid {key}"], label=f"Valid {key}")
                ax[i%3][i//3].legend()
            plt.savefig(f'{MODEL_FILE_DIRC}/Loss.png', transparent=False, facecolor='white')
            plt.close('all')

            ## 3.3. Save model and Stoping criteria
            if prev_best_valid_f1 <= valid_metric["f1_score"]:  # If previous best validation f1-score <= current f1-score 
                state_dict = {
                    "model": model.state_dict(), 
                    "epoch":epoch,
                    "valid_f1_score": valid_metric["f1_score"],
                    "valid_loss": valid_loss
                }
                torch.save(state_dict, f"{MODEL_FILE_DIRC}/{model_name}_best.pt")
                prev_best_valid_f1 = valid_metric["f1_score"]  # Previous validation loss = Current validation loss
                count = 0
            else:
                count += 1
            
            if epoch % 5 == 0:
                state_dict = {
                    "model": model.state_dict(), 
                    "epoch":epoch,
                    "valid_f1_score": valid_metric["f1_score"],
                    "valid_loss": valid_loss
                }
                torch.save(state_dict, f"{MODEL_FILE_DIRC}/{model_name}_{epoch}.pt")
            
            df.to_csv(f"{MODEL_FILE_DIRC}/Loss.csv", index=False)
            
            if count == MAX_COUNT_F1_SCORE:
                print_log(f"The validation f1 score is not increasing for continuous {MAX_COUNT_F1_SCORE} time, so training stop", MODEL_FILE_DIRC)
                break
            
            scheduler.step()

## Build and Train the SateLight model

### Build the SateLight model

In [6]:
model      = SateLight().to(device)
optimizer  = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)  
scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

# Get the following information:
# 1. Previous Trained model (if exist)
# 2. df that store the training/validation loss & metrics
# 3. epoch where the training start
# 4. Previous Highest Validation Recall
# 5. Previous Lowest  Validation Loss
model, df, EPOCH_START, prev_best_valid_f1, prev_best_valid_loss = load_classification_model_dict(model, MODEL_FILE_DIRC_SateLight, "SateLight")

# # Load the previous train model from original data
# state_dict_loaded = torch.load(MODEL_FILE_DIRC + f"/SateLight/SateLight_best.pt")
# model.load_state_dict(state_dict_loaded["model"])

# Get the summary of the model
print(summary(model, (19,1280)))

Layer (type:depth-idx)                        Output Shape              Param #
├─Sequential: 1-1                             [-1, 32, 1, 641]          --
|    └─Conv2d: 2-1                            [-1, 16, 19, 641]         10,256
|    └─Conv2d: 2-2                            [-1, 32, 1, 641]          640
├─BatchNorm1d: 1-2                            [-1, 32, 641]             64
├─ReLU: 1-3                                   [-1, 32, 641]             --
├─Dropout: 1-4                                [-1, 32, 641]             --
├─MaxPool1d: 1-5                              [-1, 32, 160]             --
├─ModuleList: 1                               []                        --
|    └─Sequential: 2-3                        [-1, 160, 32]             --
|    |    └─SelfAttention: 3-1                [-1, 160, 32]             7,392
|    |    └─BatchNorm1d: 3-2                  [-1, 160, 32]             320
|    |    └─Dropout: 3-3                      [-1, 160, 32]             --
├─ModuleLis

### Print out the model info

In [7]:
seperate = "\n" + "-" * 100 + "\n"
print(seperate + "Model infomation" + seperate)
print(f"Device used        :", device)
print(f"BATCH SIZE         :", BATCH_SIZE)
print(f"MAX_COUNT_F1_SCORE :", MAX_COUNT_F1_SCORE)
print(f"LEARNING RATE      :", LEARNING_RATE)
print(f"Prev Best f1-score in validation dataset:", prev_best_valid_f1)
print(f"Prev Best validation loss             :", prev_best_valid_loss)
print(f"Number of EPOCH for training   :",NUM_EPOCHS_CLASSIFIER, f"(EPOCH start from {EPOCH_START})")
print(f"Num of epochs of data for train:", num_train_data)
print(f"Num of epochs of data for valid:", num_valid_data)
print(f'Model parameters               : {sum(p.numel() for p in model.parameters()):,}' )


----------------------------------------------------------------------------------------------------
Model infomation
----------------------------------------------------------------------------------------------------

Device used        : cuda
BATCH SIZE         : 32
MAX_COUNT_F1_SCORE : 10
LEARNING RATE      : 0.001
Prev Best f1-score in validation dataset: -1
Prev Best validation loss             : 10000
Number of EPOCH for training   : 101 (EPOCH start from 1)
Num of epochs of data for train: 5152
Num of epochs of data for valid: 557
Model parameters               : 134,521


### Start Training Loop

In [8]:
%%time
start_classification_model_training(EPOCH_START, NUM_EPOCHS_CLASSIFIER, 
                                    model, MODEL_FILE_DIRC_SateLight, "SateLight",
                                    df, prev_best_valid_f1, prev_best_valid_loss, 
                                    train_data, num_train_data, 
                                    valid_data, num_valid_data, 
                                    scheduler, optimizer, device)

> > > Epoch     : 1
Train loss      : 0.022834234014847635
Valid loss      : 0.01330366495267707
Train precision : 0.5491803278688525
Valid precision : 0.9622641509433962
Train accuracy  : 0.7884316770186336
Valid accuracy  : 0.9353680430879713
Train f1_score  : 0.6483870967741936
Valid f1_score  : 0.85
Train recall    : 0.7913385826771654
Valid recall    : 0.7611940298507462
> > > Epoch     : 2
Train loss      : 0.012326737327836686
Valid loss      : 0.006221989072738901
Train precision : 0.8062636562272396
Valid precision : 0.9407407407407408
Train accuracy  : 0.9167313664596274
Valid accuracy  : 0.9730700179533214
Train f1_score  : 0.8376844494892168
Valid f1_score  : 0.9442379182156134
Train recall    : 0.8716535433070867
Valid recall    : 0.9477611940298507
> > > Epoch     : 3
Train loss      : 0.010174882674325108
Valid loss      : 0.00649497247697314
Train precision : 0.8554125662376987
Valid precision : 0.984375
Train accuracy  : 0.9357531055900621
Valid accuracy  : 0.982046678

### Test the model performance on testing dataset

In [9]:
# Load the best model and turn to evaluation mode
model.eval()
state_dict_loaded = torch.load(f"{MODEL_FILE_DIRC_SateLight}/SateLight_best.pt")
model.load_state_dict(state_dict_loaded["model"])

test_loss, test_metric = evaluate_classifier(model, test_data, device, num_test_data, LOSS_POS_WEIGHT=torch.tensor([3])) 

print("Best model is at epoch:",state_dict_loaded["epoch"])
print("Metric on testing dataset:")
for key, value in test_metric.items():
    print(f"{key:<10}: {value:.4f}")

Best model is at epoch: 16
Metric on testing dataset:
precision : 0.9197
accuracy  : 0.9659
f1_score  : 0.9299
recall    : 0.9403


## Build and Train the Simple CNN model

### Build the CNN model

In [10]:
model      = CNN().to(device)
optimizer  = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)  
scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

# Get the following information:
# 1. Previous Trained model (if exist)
# 2. df that store the training/validation loss & metrics
# 3. epoch where the training start
# 4. Previous Highest Validation Recall
# 5. Previous Lowest  Validation Loss
model, df, EPOCH_START, prev_best_valid_f1, prev_best_valid_loss = load_classification_model_dict(model, MODEL_FILE_DIRC_CNN, "CNN")

# # Load the previous train model from original data
# state_dict_loaded = torch.load(MODEL_FILE_DIRC + f"/CNN/CNN_best.pt")
# model.load_state_dict(state_dict_loaded["model"])

# Get the summary of the model
print(summary(model, (19,1280)))

Layer (type:depth-idx)                        Output Shape              Param #
├─Sequential: 1-1                             [-1, 128, 8]              --
|    └─Conv1dWithInitialization: 2-1          [-1, 128, 1278]           --
|    |    └─Conv1d: 3-1                       [-1, 128, 1278]           7,424
|    └─BatchNorm1d: 2-2                       [-1, 128, 1278]           256
|    └─ReLU: 2-3                              [-1, 128, 1278]           --
|    └─MaxPool1d: 2-4                         [-1, 128, 639]            --
|    └─Conv1dWithInitialization: 2-5          [-1, 64, 637]             --
|    |    └─Conv1d: 3-2                       [-1, 64, 637]             24,640
|    └─BatchNorm1d: 2-6                       [-1, 64, 637]             128
|    └─ReLU: 2-7                              [-1, 64, 637]             --
|    └─MaxPool1d: 2-8                         [-1, 64, 318]             --
|    └─Conv1dWithInitialization: 2-9          [-1, 32, 316]             --
|    |    └

### Print out the model info

In [11]:
seperate = "\n" + "-" * 100 + "\n"
print(seperate + "Model infomation" + seperate)
print(f"Device used        :", device)
print(f"BATCH SIZE         :", BATCH_SIZE)
print(f"MAX_COUNT_F1_SCORE :", MAX_COUNT_F1_SCORE)
print(f"LEARNING RATE      :", LEARNING_RATE)
print(f"Prev Best recall in validation dataset:", prev_best_valid_f1)
print(f"Prev Best validation loss             :", prev_best_valid_loss)
print(f"Number of EPOCH for training   :",NUM_EPOCHS_CLASSIFIER, f"(EPOCH start from {EPOCH_START})")
print(f"Num of epochs of data for train:", num_train_data)
print(f"Num of epochs of data for valid:", num_valid_data)
print(f'Model parameters               : {sum(p.numel() for p in model.parameters()):,}' )


----------------------------------------------------------------------------------------------------
Model infomation
----------------------------------------------------------------------------------------------------

Device used        : cuda
BATCH SIZE         : 32
MAX_COUNT_F1_SCORE : 10
LEARNING RATE      : 0.001
Prev Best recall in validation dataset: -1
Prev Best validation loss             : 10000
Number of EPOCH for training   : 101 (EPOCH start from 1)
Num of epochs of data for train: 5152
Num of epochs of data for valid: 557
Model parameters               : 74,225


### Start Training Loop

In [12]:
%%time
start_classification_model_training(EPOCH_START, NUM_EPOCHS_CLASSIFIER, 
                                    model, MODEL_FILE_DIRC_CNN, "CNN",
                                    df, prev_best_valid_f1, prev_best_valid_loss, 
                                    train_data, num_train_data, 
                                    valid_data, num_valid_data, 
                                    scheduler, optimizer, device)

> > > Epoch     : 1
Train loss      : 0.02110473745971085
Valid loss      : 0.009032030481196392
Train precision : 0.5571677307022634
Valid precision : 0.7344632768361582
Train accuracy  : 0.7917313664596274
Valid accuracy  : 0.9084380610412927
Train f1_score  : 0.6414968259271634
Valid f1_score  : 0.8360128617363344
Train recall    : 0.7559055118110236
Valid recall    : 0.9701492537313433
> > > Epoch     : 2
Train loss      : 0.008710967061205287
Valid loss      : 0.0051763932717231705
Train precision : 0.8756674294431731
Valid precision : 0.9618320610687023
Train accuracy  : 0.9446816770186336
Valid accuracy  : 0.9766606822262118
Train f1_score  : 0.8895776830685781
Valid f1_score  : 0.9509433962264151
Train recall    : 0.9039370078740158
Valid recall    : 0.9402985074626866
> > > Epoch     : 3
Train loss      : 0.0059012600201939856
Valid loss      : 0.0029768781682988042
Train precision : 0.9083269671504965
Valid precision : 0.9692307692307692
Train accuracy  : 0.9609860248447205
V

### Test the model performance on testing dataset

In [13]:
# Load the best model and turn to evaluation mode
model.eval()
state_dict_loaded = torch.load(f"{MODEL_FILE_DIRC_CNN}/CNN_best.pt")
model.load_state_dict(state_dict_loaded["model"])

test_loss, test_metric = evaluate_classifier(model, test_data, device, num_test_data, LOSS_POS_WEIGHT=torch.tensor([3]))

print("Best model is at epoch:",state_dict_loaded["epoch"])
print("Metric on testing dataset:")
for key, value in test_metric.items():
    print(f"{key:<10}: {value:.4f}")

Best model is at epoch: 56
Metric on testing dataset:
precision : 1.0000
accuracy  : 0.9820
f1_score  : 0.9612
recall    : 0.9254
